import pybullet as p
import time
import os
import cv2
import numpy as np
import pybullet_data
import pdb
import importlib.util
import argparse
from env.fetch_controller import FetchController

class FetchSimulation:
    def __init__(self, gui=True, debug=False):
        """Connect to PyBullet and build the scene.

        The scene is created in this order: ground plane, boundary walls,
        road, Fetch robot, plate (attached to the gripper), ball and table.

        Args:
            gui: When True, connect with ``p.GUI`` and enable real-time
                simulation. Otherwise connect with ``p.DIRECT``, disable
                real-time simulation and use a fixed 1/60 s time step.
            debug: Stored as ``self.debug`` and forwarded to
                ``FetchController``.
        """
        try:
            # Connect to the physics server first. The client id is kept so
            # that close() can disconnect this specific connection.
            if gui:
                self.physics_client = p.connect(p.GUI)
                p.setRealTimeSimulation(1)
            else:
                # Always use CPU rendering (DIRECT mode) when gui is False
                self.physics_client = p.connect(p.DIRECT)
                print("Using CPU rendering (DIRECT mode)")
                p.setRealTimeSimulation(0)
                p.setTimeStep(1/60)

            # Physics engine tuning (file caching, solver and contact settings)
            p.setPhysicsEngineParameter(enableFileCaching=0)
            p.setPhysicsEngineParameter(numSolverIterations=50)
            p.setPhysicsEngineParameter(enableConeFriction=0)
            p.setPhysicsEngineParameter(contactBreakingThreshold=0.001)
            p.setPhysicsEngineParameter(allowedCcdPenetration=0.0)
            p.setPhysicsEngineParameter(numSubSteps=2)
            p.setPhysicsEngineParameter(enableSAT=0)

            # Report whether this PyBullet build has NumPy support
            numpy_enabled = p.isNumpyEnabled()
            print(f"PyBullet using NumPy: {numpy_enabled}")
            if not numpy_enabled:
                print("Warning: PyBullet was not compiled with NumPy support")
        except Exception as e:
            # Best effort: keep going with a plain DIRECT connection.
            print(f"Warning: Could not initialize simulation with preferred settings: {e}")
            print("Falling back to basic CPU mode")
            if not p.isConnected():
                self.physics_client = p.connect(p.DIRECT)

        # Set gravity
        p.setGravity(0, 0, -9.81)
        self.debug = debug
        self.initial_fetch_pos = [-2, 0, 0.1]  # Set Z offset to match wheel height from URDF
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.fetch_ros_path = os.path.join(current_dir, "..", "data", "fetch_ros")
        p.setAdditionalSearchPath(self.fetch_ros_path)
        self.plane_id = p.loadURDF("plane.urdf", basePosition=[0, 0, 0], useFixedBase=True)
        self._load_wall()
        self._load_road()
        self.ball_drop_penalized = False

        # Load the Fetch robot (load_robots sets self.fetch_robot on success)
        self.fetch_robot = None
        self.load_robots()
        self.start_positions = [0.0, 1.4, 0.0, -2.2, 0.0, 0.75, 1.55]
        self.start_torso_height = 0.2
        self.initial_orientation = p.getQuaternionFromEuler([0, 0, 0])
        self.initial_yaw = 0.0

        # Always define the controller attribute so later code can test it
        # for None instead of hitting an AttributeError.
        self.fetch_controller = None
        if self.fetch_robot is not None:
            self.fetch_controller = FetchController(self.fetch_robot, debug=self.debug)
            _, orientation = p.getBasePositionAndOrientation(self.fetch_robot)
            euler_angles = p.getEulerFromQuaternion(orientation)
            # Wrap yaw to [-pi, pi)
            self.initial_yaw = ((euler_angles[2] + np.pi) % (2 * np.pi)) - np.pi
        else:
            print("Warning: Can't initialize controller because robot failed to load")

        self.plate_id = self.load_plate()
        # Defined up front; attach_plate_to_gripper() overwrites it on success.
        self.plate_constraint = None
        self.attach_plate_to_gripper()
        self.ball_id = self.load_ball()
        self.load_table()

        self.current_step = 0
        self.max_episode_steps = 512
        self.end_reason = ""
        self.current_action = None  # Store the current action
        self.move_base = True
        self.has_rotated = False

        ############## These values might need to be changed #############
        self.goal_distance = 2.0  # Nominal start-to-goal distance in metres
        self.distance_threshold = 0.3  # Planar distance at which the goal counts as reached
        # Goal in world coordinates. The robot starts at x = -2, so the goal
        # is 2.2 m ahead of it along +x.
        self.goal_position = [0.2, 0, 0]

        # Progress-tracking state read by compute_reward(); reset() refreshes
        # it, but defining it here keeps step() usable before the first reset().
        self.last_pos = np.array(self.initial_fetch_pos)
        self.last_distance_to_goal = self.goal_distance
        self.last_yaw = self.initial_yaw
        
    def load_robots(self):
        """Load the Fetch URDF and lighten its arm links.

        The URDF path is relative ("fetch_description/robots/fetch.urdf").
        Every link whose name contains "arm", "shoulder", "elbow" or "wrist"
        gets its mass set to 0.5 to prevent the robot from tilting.

        Returns:
            The body id stored in ``self.fetch_robot``, or None if any step
            raised (the error is printed).
        """
        arm_keywords = ("arm", "shoulder", "elbow", "wrist")
        try:
            urdf_rel_path = os.path.join("fetch_description", "robots", "fetch.urdf")
            self.fetch_robot = p.loadURDF(
                urdf_rel_path,
                basePosition=self.initial_fetch_pos,
                useFixedBase=False,
            )

            # Joint index i also addresses the child link of joint i.
            for joint_idx in range(p.getNumJoints(self.fetch_robot)):
                child_link = p.getJointInfo(self.fetch_robot, joint_idx)[12].decode('utf-8')
                if any(keyword in child_link for keyword in arm_keywords):
                    p.changeDynamics(self.fetch_robot, joint_idx, mass=0.5)

            return self.fetch_robot

        except Exception as e:
            print(f"Error loading robot: {e}")
            return None
     
    def reset(self):
        """Start a new episode and return the first observation.

        Resets the step counter, the ball-drop flag and the termination
        reason, puts the robot base back at ``initial_fetch_pos``, moves the
        arm to ``start_positions``, lets the simulation settle for 10 steps,
        and places the ball 0.3 m above the plate (loading the ball first if
        it does not exist yet).

        Returns:
            Whatever ``get_state()`` returns for the freshly reset scene.
        """
        self.current_step = 0
        self.ball_drop_penalized = False
        # Clear the previous episode's reason so step() does not report a
        # stale value on non-terminal steps.
        self.end_reason = ""

        p.resetBasePositionAndOrientation(
            self.fetch_robot,
            self.initial_fetch_pos,
            self.initial_orientation
        )

        self.fetch_controller.set_arm_position(self.start_positions)

        # Allow the simulation to settle after resetting the arm. The
        # connection type cannot change inside the loop, so query it once.
        gui_connected = p.getConnectionInfo()["connectionMethod"] == p.GUI
        for _ in range(10):
            p.stepSimulation()
            if gui_connected:
                time.sleep(1./240.)  # Optional small delay for GUI visualization

        # Initialize last-state variables used by compute_reward()
        base_pos, orientation = p.getBasePositionAndOrientation(self.fetch_robot)
        self.last_pos = np.array(self.initial_fetch_pos)
        # Seed progress tracking with the real planar distance (same formula
        # as step()), so the first step's progress is measured correctly
        # instead of against the nominal goal_distance.
        self.last_distance_to_goal = np.sqrt(
            (base_pos[0] - self.goal_position[0])**2 +
            (base_pos[1] - self.goal_position[1])**2
        )
        euler_angles = p.getEulerFromQuaternion(orientation)
        self.last_yaw = ((euler_angles[2] + np.pi) % (2 * np.pi)) - np.pi

        # Place the ball directly above the plate when the plate exists
        if self.plate_id is not None:
            plate_pos, _ = p.getBasePositionAndOrientation(self.plate_id)
            ball_pos = [plate_pos[0], plate_pos[1], plate_pos[2] + 0.3]  # 30cm above plate
        else:
            ball_pos = [-1.1, 0, 1.15]  # Default position if plate not loaded

        # Load the ball if it is missing, then move it to the start pose
        if getattr(self, 'ball_id', None) is None:
            self.ball_id = self.load_ball()
        if self.ball_id is not None:
            p.resetBasePositionAndOrientation(
                self.ball_id,
                posObj=ball_pos,
                ornObj=p.getQuaternionFromEuler([0, 0, 0])
            )

        # Return initial state
        return self.get_state()
    
    def step(self, action):
        """Apply a base command, advance physics one step and evaluate it.

        Args:
            action: Indexable with at least two entries; ``action[0]`` and
                ``action[1]`` are forwarded to
                ``FetchController.move_base_velocity``.
                # presumably linear and angular velocity - confirm against the controller

        Returns:
            tuple: ``(state, reward, done, info)``. ``info`` holds the index
            of this step (``'step'``), the termination reason (``'reason'``),
            the current planar distance to the goal (stored under the key
            ``'closest_distance_to_goal'``) and the base position
            (``'robot_pos'``).
        """
        self.current_action = action

        # Command the base BEFORE stepping physics so the returned state and
        # reward reflect this action rather than the previous one.
        self.fetch_controller.move_base_velocity(action[0], action[1])
        p.stepSimulation()

        # Get current robot position
        robot_pos, _ = p.getBasePositionAndOrientation(self.fetch_robot)

        # Planar (x, y) distance to goal
        distance_to_goal = np.sqrt(
            (robot_pos[0] - self.goal_position[0])**2 +
            (robot_pos[1] - self.goal_position[1])**2
        )

        # Count this step before the termination check so an episode lasts
        # exactly max_episode_steps steps; info['step'] keeps the 0-based
        # index of the step that was just executed.
        step_index = self.current_step
        self.current_step += 1

        state = self.get_state()
        reward = self.compute_reward(state, distance_to_goal)
        done = self.is_done(state, distance_to_goal)

        info = {
            'step': step_index,
            'reason': self.end_reason,
            'closest_distance_to_goal': distance_to_goal,
            'robot_pos': robot_pos,
        }

        if self.debug:
            print(f"Distance to goal: {distance_to_goal}")
            print(f"Robot position: {robot_pos}")
            print(f"Goal position: {self.goal_position}")

        return state, reward, done, info
    
    def compute_reward(self, state, distance_to_goal):
        """Return the reward for the current step and update progress tracking.

        The reward is 200 times the reduction in distance to the goal since
        the previous call (never negative), plus a bonus of 100 when the
        robot is within ``distance_threshold`` of the goal. Ball-balancing
        and rotation-penalty terms are currently disabled.

        Args:
            state: Current observation; not used by the active reward terms.
            distance_to_goal: Planar distance from the robot base to the goal.

        Returns:
            The reward for this step.

        Side effects:
            Updates ``last_pos``, ``last_distance_to_goal`` and ``last_yaw``.
        """
        progress_reward_scale = 200.0

        # Snapshot the base pose; yaw is wrapped to [-pi, pi)
        base_xyz, base_quat = p.getBasePositionAndOrientation(self.fetch_robot)
        base_xyz = np.array(base_xyz)
        raw_yaw = p.getEulerFromQuaternion(base_quat)[2]
        wrapped_yaw = ((raw_yaw + np.pi) % (2 * np.pi)) - np.pi

        reward = 0

        # Forward progress: only getting closer is rewarded
        gained = max(0.0, self.last_distance_to_goal - distance_to_goal)
        reward += gained * progress_reward_scale

        # Bonus for reaching the goal
        if distance_to_goal < self.distance_threshold:
            reward += 100.0

        # Remember this step's values for the next call
        self.last_pos = base_xyz
        self.last_distance_to_goal = distance_to_goal
        self.last_yaw = wrapped_yaw

        return reward
    
     
    def is_done(self, state, distance_to_goal):
        """Report whether the episode has ended and record why.

        Three conditions are checked in order: the ball's height dropped
        below 0.7 ("ball_dropped"), the robot is within
        ``distance_threshold`` of the goal ("success"), and the step counter
        reached ``max_episode_steps`` ("episode_limit"). When several hold at
        once, the last one in that order is stored in ``self.end_reason``.
        ``end_reason`` is left untouched when none hold.

        Args:
            state: Current observation; not used here.
            distance_to_goal: Planar distance from the robot base to the goal.

        Returns:
            bool: True if at least one condition holds.
        """
        triggered = []

        # Ball has fallen off the plate
        if self.ball_id is not None:
            ball_xyz = p.getBasePositionAndOrientation(self.ball_id)[0]
            if ball_xyz[2] < 0.7:
                triggered.append("ball_dropped")

        # Robot is close enough to the goal
        if distance_to_goal < self.distance_threshold:
            triggered.append("success")

        # Step budget exhausted
        if self.current_step >= self.max_episode_steps:
            triggered.append("episode_limit")

        if not triggered:
            return False

        # Later checks take precedence over earlier ones
        self.end_reason = triggered[-1]
        return True
    
    def get_state(self):
        """Return the current observation: the controller's depth image.

        ``fetch_controller.get_camera_image()`` must return exactly three
        values; the second one (the depth image) is returned unchanged.
        LIDAR data is not collected at the moment, so the observation is the
        depth image alone rather than a dictionary.

        Returns:
            The depth image as produced by the controller.
        """
        _rgb, depth_image, _extra = self.fetch_controller.get_camera_image()
        return depth_image


    def get_env_image(self, width_px=320, height_px=320):
        """Render an overview image of the whole environment.

        The virtual camera looks at the point (0, 0, 0.7) from 2.5 units
        away with yaw 0, pitch -30 degrees and no roll, using a 60 degree
        field of view.

        Args:
            width_px: Image width in pixels (default 320).
            height_px: Image height in pixels (default 320).

        Returns:
            tuple: ``(rgb_array, depthImg, segImg)`` where ``rgb_array`` has
            shape ``(height_px, width_px, 3)`` (alpha channel dropped) and
            the depth and segmentation buffers are passed through exactly as
            returned by ``p.getCameraImage``.
        """
        # View matrix for the fixed overview camera
        view_matrix = p.computeViewMatrixFromYawPitchRoll(
            cameraTargetPosition=[0, 0, 0.7],
            distance=2.5,
            yaw=0,
            pitch=-30,                         # Slightly downward angle
            roll=0,                            # Keep horizon level
            upAxisIndex=2
        )

        # Projection matrix; aspect ratio follows the requested resolution
        proj_matrix = p.computeProjectionMatrixFOV(
            fov=60,
            aspect=float(width_px) / float(height_px),
            nearVal=0.1,
            farVal=100.0
        )

        # Render
        _, _, rgbImg, depthImg, segImg = p.getCameraImage(
            width_px, height_px,
            viewMatrix=view_matrix,
            projectionMatrix=proj_matrix
        )

        # The colour buffer is RGBA; reshape it and drop the alpha channel
        rgb_array = np.reshape(rgbImg, (height_px, width_px, 4))[:, :, :3]
        return rgb_array, depthImg, segImg
    
    def close(self):
        """Disconnect from the physics server if a connection is open.

        When the instance has a ``physics_client`` attribute and that client
        is connected, that client is disconnected. Otherwise the default
        connection is disconnected if it is open. Nothing happens (and
        nothing is printed) when there is no open connection.
        """
        if hasattr(self, 'physics_client') and p.isConnected(self.physics_client):
            p.disconnect(self.physics_client)
        elif p.isConnected():
            p.disconnect()
        else:
            return
        print("Disconnected from physics server")

    def load_plate(self):
        """Load ``plate.urdf`` from the ``../data`` directory next to this file.

        The plate is created as a free (non-fixed) body at (0.8, 0, 1) and
        its base mass is then set to 0.01.

        Returns:
            The plate's body id, or None when the URDF file is missing or
            loading raised (a message is printed in both cases).
        """
        try:
            data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data")
            plate_urdf_path = os.path.join(data_dir, "plate.urdf")

            if os.path.exists(plate_urdf_path):
                identity_quat = p.getQuaternionFromEuler([0, 0, 0])
                plate_id = p.loadURDF(
                    plate_urdf_path,
                    basePosition=[0.8, 0, 1],  # Placeholder pose; changes once attached
                    baseOrientation=identity_quat,
                    useFixedBase=False,
                    globalScaling=1.0,
                )
                p.changeDynamics(plate_id, -1, mass=0.01)
                return plate_id

            print(f"Error: Plate URDF not found at {plate_urdf_path}")
            return None

        except Exception as e:
            print(f"Error loading plate: {e}")
            return None
    
    def load_ball(self):
        """Load ``ball.urdf`` from the ``../data`` directory next to this file.

        The ball is spawned 0.3 above the plate's current position when a
        plate is loaded, otherwise at the default position (-1.1, 0, 1.15).

        Returns:
            The ball's body id, or None when the URDF file is missing or
            loading raised (a message is printed in both cases).
        """
        try:
            data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data")
            ball_urdf_path = os.path.join(data_dir, "ball.urdf")

            if not os.path.exists(ball_urdf_path):
                print(f"Error: Ball URDF not found at {ball_urdf_path}")
                return None

            if self.plate_id is None:
                spawn_xyz = [-1.1, 0, 1.15]  # Default position if plate not loaded
            else:
                plate_xyz = p.getBasePositionAndOrientation(self.plate_id)[0]
                # Directly above the plate, 0.3 higher
                spawn_xyz = [plate_xyz[0], plate_xyz[1], plate_xyz[2] + 0.3]

            return p.loadURDF(
                ball_urdf_path,
                basePosition=spawn_xyz,
                baseOrientation=p.getQuaternionFromEuler([0, 0, 0]),
                useFixedBase=False,
                globalScaling=1.0,
            )

        except Exception as e:
            print(f"Error loading ball: {e}")
            return None

    def attach_plate_to_gripper(self):
        """Fix the plate to the robot's gripper link and close the fingers on it.

        A fixed constraint is created between ``gripper_link`` (offset 0.23
        along the parent frame's x axis) and the plate base, collisions
        between the robot and the plate are disabled, and both finger joints
        are set to and held at 0.015 with position control.

        ``self.plate_constraint`` holds the constraint id on success. It is
        guaranteed to exist afterwards (None when nothing has been attached),
        so callers can test it without risking an AttributeError.

        Returns:
            None. Failures are reported with a printed message.
        """
        # Keep an existing constraint id if there is one; otherwise define
        # the attribute so the early-return paths leave it set.
        self.plate_constraint = getattr(self, "plate_constraint", None)

        if self.plate_id is None or self.fetch_robot is None:
            print("Cannot attach plate: plate or robot not loaded")
            return

        # Find the gripper link and both finger joints in one pass
        gripper_link_index = -1
        left_joint_index = -1
        right_joint_index = -1
        num_joints = p.getNumJoints(self.fetch_robot)

        for i in range(num_joints):
            joint_info = p.getJointInfo(self.fetch_robot, i)
            joint_name = joint_info[1].decode('utf-8')
            link_name = joint_info[12].decode('utf-8')

            if link_name == "gripper_link":
                gripper_link_index = i
            if joint_name == "l_gripper_finger_joint":
                left_joint_index = i
            elif joint_name == "r_gripper_finger_joint":
                right_joint_index = i

            # Stop scanning once everything has been found
            if -1 not in (gripper_link_index, left_joint_index, right_joint_index):
                break

        # Check if essential indices were found
        if gripper_link_index == -1:
            print("Error: Could not find gripper_link")
            return
        if left_joint_index == -1 or right_joint_index == -1:
            print("Error: Could not find gripper finger joints")
            return

        # Create the fixed constraint; the plate frame is rotated about x
        # by 1.5708 rad relative to the gripper frame.
        plate_orientation = p.getQuaternionFromEuler([1.5708, 0, 0])
        self.plate_constraint = p.createConstraint(
            parentBodyUniqueId=self.fetch_robot,
            parentLinkIndex=gripper_link_index,
            childBodyUniqueId=self.plate_id,
            childLinkIndex=-1,
            jointType=p.JOINT_FIXED,
            jointAxis=[0, 0, 0],
            parentFramePosition=[0.23, 0, 0],
            childFramePosition=[0, 0, 0],
            parentFrameOrientation=[0, 0, 0, 1],
            childFrameOrientation=plate_orientation
        )

        # Disable collisions between every robot link and the plate,
        # then between the robot base (-1) and the plate.
        for i in range(num_joints):
            p.setCollisionFilterPair(self.fetch_robot, self.plate_id, i, -1, enableCollision=0)
        p.setCollisionFilterPair(self.fetch_robot, self.plate_id, -1, -1, enableCollision=0)

        # Set both finger positions, then hold them with position control
        finger_position = 0.015
        finger_joints = (left_joint_index, right_joint_index)
        for joint_index in finger_joints:
            p.resetJointState(self.fetch_robot, joint_index, finger_position)
        for joint_index in finger_joints:
            p.setJointMotorControl2(
                bodyUniqueId=self.fetch_robot,
                jointIndex=joint_index,
                controlMode=p.POSITION_CONTROL,
                targetPosition=finger_position,
                force=10.0  # Use a reasonable force
            )

    def detect_red_sphere(self):
        """Check the camera view for a red object within the expected depth band.

        A pixel counts when its depth value lies strictly between 0.6 and
        1.4 and its HSV colour falls in one of two red ranges (hue 0-10 or
        170-180, saturation >= 120, value >= 70). The ball is considered
        detected when more than 60 pixels qualify.

        Returns:
            True-like if more than 60 qualifying pixels were found, otherwise
            False-like.
        """
        min_pixels_threshold = 60
        depth_lo, depth_hi = 0.6, 1.4
        # (lower, upper) HSV bounds; red wraps around the hue axis, hence two bands
        red_bands = (
            (np.array([0, 120, 70]), np.array([10, 255, 255])),
            (np.array([170, 120, 70]), np.array([180, 255, 255])),
        )

        rgb_img, depth_img, _ = self.fetch_controller.get_camera_image()

        # 1. Depth band
        in_depth_band = (depth_img > depth_lo) & (depth_img < depth_hi)

        # 2. Red colour in HSV space
        hsv_image = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2HSV)
        low_band_mask = cv2.inRange(hsv_image, red_bands[0][0], red_bands[0][1])
        high_band_mask = cv2.inRange(hsv_image, red_bands[1][0], red_bands[1][1])
        red_mask = low_band_mask | high_band_mask

        # 3. Pixels that are both red and at the right depth
        object_pixels = np.sum(in_depth_band & (red_mask > 0))
        detected = object_pixels > min_pixels_threshold

        if self.debug:
            if detected:
                print(f"Red ball detected at {object_pixels} pixels")
            else:
                print("No red ball detected")
                print(f"Object pixels: {object_pixels}")

        return detected

    def load_table(self):
        """Load ``table.urdf`` from the ``../data`` directory next to this file.

        The table is created as a fixed body at (1, 0, 0), rotated 90
        degrees about the z axis.

        Returns:
            The table's body id, or None when the URDF file is missing or
            loading raised (a message is printed in both cases).
        """
        try:
            data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data")
            table_urdf_path = os.path.join(data_dir, "table.urdf")

            if os.path.exists(table_urdf_path):
                # Quarter turn about z (pi/2 radians)
                quarter_turn = p.getQuaternionFromEuler([0, 0, np.pi/2])
                return p.loadURDF(
                    table_urdf_path,
                    basePosition=[1, 0, 0],
                    baseOrientation=quarter_turn,
                    useFixedBase=True,
                )

            print(f"Error: Table URDF not found at {table_urdf_path}")
            return None

        except Exception as e:
            print(f"Error loading table: {e}")
            return None
        
    def _load_wall(self):
        """Create four static, cyan box walls that enclose the scene.

        The walls are 1.5 high and 0.01 thick, and sit 0.5 beyond half the
        4.27 road length on the north (+y), south (-y), east (+x) and west
        (-x) sides. They are created in that order.
        """
        wall_height = 1.5
        wall_thickness = 0.01
        wall_rgba = [0, 1, 1, 1]  # Cyan
        extension = 1
        road_length = 4.27

        wall_span = road_length + 2*wall_thickness + extension
        edge_offset = road_length/2 + 0.5
        center_z = wall_height/2

        # (center_x, center_y, size_x, size_y) for each wall
        layout = (
            (0, edge_offset, wall_span, wall_thickness),    # North
            (0, -edge_offset, wall_span, wall_thickness),   # South
            (edge_offset, 0, wall_thickness, wall_span),    # East
            (-edge_offset, 0, wall_thickness, wall_span),   # West
        )

        for center_x, center_y, size_x, size_y in layout:
            half_extents = [size_x/2, size_y/2, wall_height/2]
            visual_id = p.createVisualShape(
                shapeType=p.GEOM_BOX,
                halfExtents=half_extents,
                rgbaColor=wall_rgba,
            )
            collision_id = p.createCollisionShape(
                shapeType=p.GEOM_BOX,
                halfExtents=half_extents,
            )
            p.createMultiBody(
                baseMass=0,  # Static object
                baseCollisionShapeIndex=collision_id,
                baseVisualShapeIndex=visual_id,
                basePosition=[center_x, center_y, center_z],
            )
    
    def _load_road(self):
        """Create the road as a single static dark-gray box.

        The slab is 4.5 long (x), 1.2 wide (y) and 0.05 high, centred at
        (0, 0, 0.04) with no rotation. It is built from primitive shapes;
        no file is read.
        """
        road_rgba = [0.1, 0.1, 0.1, 1]  # Dark gray

        # One entry per slab: (x, y, z, roll, pitch, yaw, length, width, height)
        slabs = (
            (0, 0, 0.04, 0, 0, 0, 4.5, 1.2, 0.05),  # East-West road
        )

        for x, y, z, roll, pitch, yaw, length, width, height in slabs:
            visual_id = p.createVisualShape(
                shapeType=p.GEOM_BOX,
                halfExtents=[length/2, width/2, height/2],
                rgbaColor=road_rgba,
            )
            collision_id = p.createCollisionShape(
                shapeType=p.GEOM_BOX,
                halfExtents=[length/2, width/2, height/2],
            )
            p.createMultiBody(
                baseMass=0,  # Static object
                baseCollisionShapeIndex=collision_id,
                baseVisualShapeIndex=visual_id,
                basePosition=[x, y, z],
                baseOrientation=p.getQuaternionFromEuler([roll, pitch, yaw]),
            )
