import pybullet as p
import time
import os
import cv2
import numpy as np
import pybullet_data
import importlib.util
from env.controller import TurtlebotController
from env.intersection import create_intersection
	
class TurtlebotSim:
	def __init__(self, gui, use_egl, tb2_speed, tb3_speed, obs_size, action_size):
		"""Connect to PyBullet, build the intersection and spawn both TurtleBots.

		Args:
			gui: Truthy to connect with ``p.GUI``; otherwise ``p.DIRECT`` is used.
			use_egl: Truthy to try loading the ``eglRenderer`` plugin.
			tb2_speed: Linear speed passed to the east robot's controller in ``step``.
			tb3_speed: Accepted for interface compatibility; not used by this class.
			obs_size: Forwarded to each ``TurtlebotController``.
			action_size: Forwarded to each ``TurtlebotController``.

		Raises:
			SystemExit: With code 1 when loading a robot URDF raises ``p.error``.
		"""
		# Connection and renderer setup is best-effort: any failure is reported
		# and a DIRECT connection is opened if none exists yet. Note that a
		# failure part-way through skips the remaining engine parameters.
		try:
			if gui:
				p.connect(p.GUI)
			else:
				p.connect(p.DIRECT)

			# Physics engine tuning
			p.setPhysicsEngineParameter(enableFileCaching=0)
			p.setPhysicsEngineParameter(numSolverIterations=50)
			p.setPhysicsEngineParameter(enableConeFriction=0)
			p.setPhysicsEngineParameter(contactBreakingThreshold=0.001)
			p.setPhysicsEngineParameter(allowedCcdPenetration=0.0)
			p.setPhysicsEngineParameter(numSubSteps=2)
			p.setPhysicsEngineParameter(enableSAT=0)

			# Report whether this PyBullet build returns NumPy arrays
			numpy_enabled = p.isNumpyEnabled()
			print(f"PyBullet using NumPy: {numpy_enabled}")
			if not numpy_enabled:
				print("Warning: PyBullet was not compiled with NumPy support")

			if use_egl:  # Only for headless mode
				egl = importlib.util.find_spec('eglRenderer')
				if egl:
					p.loadPlugin(egl.origin, "_eglRendererPlugin")
					print("Loaded EGL plugin for GPU rendering")
		except Exception as e:
			print(f"Warning: Could not initialize GPU rendering: {e}")
			print("Falling back to CPU rendering")
			if not p.isConnected():
				p.connect(p.DIRECT)

		p.setGravity(0, 0, -9.81)
		p.setAdditionalSearchPath(os.path.join(os.path.dirname(__file__), "..", "data"))

		self.east_robot_speed = tb2_speed
		self.obs_size = obs_size
		self.action_size = action_size

		# Road geometry and episode length limit
		self.road_length = 10.0
		self.road_width = 2.0
		self.road_height = 0.1
		self.max_episode_steps = 1000

		# is_done()/compute_reward() treat state[1] above this value as success
		self.success_threshold = 1.2

		# Spawn pose of the learning robot: 0.5 inside the +x road edge and
		# 1.5 in from the -y end of the road.
		self.offset = [
			self.road_width/2 - 0.5,
			-self.road_length/2 + 1.5,
			0.1
		]
		# NOTE(review): yaw = pi/2 presumably points the robot along +y (towards
		# the goal used by get_state) - confirm against the URDF's forward axis.
		self.initial_orientation = p.getQuaternionFromEuler([0, 0, np.pi/2])

		# Spawn pose of the scripted east robot: 2.0 in from the +x end of the
		# road and 0.8 inside the +y road edge.
		self.east_position = [
			self.road_length/2 - 2.0,
			self.road_width/2 - 0.8,
			0.2
		]
		# NOTE(review): yaw = pi presumably points the robot along -x - confirm.
		self.east_orientation = p.getQuaternionFromEuler([0, 0, np.pi])

		self.urdf_path = os.path.join(os.path.dirname(__file__), "..", "data", "turtlebot.urdf")

		# Build the static scene from the instance geometry. This must happen
		# before the robots are loaded so body creation order is unchanged.
		self.create_intersection(self.road_length, self.road_width, self.road_height)

		# Normalized action bounds; apply_action() scales them by the max speeds
		self.action_space = {
			'linear': [-1.0, 1.0],
			'angular': [-1.0, 1.0]
		}
		self.max_linear_speed = 0.65  # Maximum linear velocity in m/s
		self.max_angular_speed = 1.0  # Maximum angular velocity in rad/s
		self.max_force = 100.0

		# Episode outcome and per-episode flags. They are initialised here so
		# they exist even when step() is used before the first reset().
		self.end_reason = ""
		self.has_stopped_at_line = False
		self.has_penalized_no_stop = False
		self.has_reached_target = False
		self.has_collided = False

		try:
			# Learning robot (south lane)
			robot_id_1 = p.loadURDF(
				self.urdf_path,
				self.offset,
				self.initial_orientation
			)
			self.turtle = robot_id_1
			self.agent = TurtlebotController(robot_id_1, self.obs_size, self.action_size)

			# Scripted robot (east lane)
			robot_id_2 = p.loadURDF(
				self.urdf_path,
				self.east_position,
				self.east_orientation
			)
			self.turtle_east = robot_id_2
			self.agent_east = TurtlebotController(robot_id_2, self.obs_size, self.action_size)

			# Body id and controller for every robot, keyed by name
			self.robots = {
				'turtle_south': (robot_id_1, self.agent),
				'turtle_east': (robot_id_2, self.agent_east)
			}
		except p.error as e:
			print(f"Error loading URDF files: {e}")
			p.disconnect()
			# Same exit status as before, without relying on the site-provided
			# exit() helper being available.
			raise SystemExit(1) from e

		# Simulation is advanced manually, one 1/30 s tick per stepSimulation()
		p.setRealTimeSimulation(0)
		p.setTimeStep(1/30)

		print("PyBullet Version:", p.getAPIVersion())

		# Step counter and progress memory used by compute_reward()
		self.current_step = 0
		self.previous_distance = None

		# East robot stop-line bookkeeping (see step())
		self.east_robot_stopped = False
		self.east_stop_counter = 0
		self.stop_duration = 5
		self.safe_to_cross = False
		self.cut_off = False
		self.passed_stop = False

		# Last known east robot position; refreshed in reset() and step()
		self.east_pos = [0, 0, 0]
		

	def create_intersection(self, road_length=10.0, road_width=2.0, road_height=0.1):
		"""Build the intersection by delegating to the imported ``create_intersection`` helper.

		Args:
			road_length: Passed through as the helper's first argument.
			road_width: Passed through as the helper's second argument.
			road_height: Passed through as the helper's third argument.
		"""
		dimensions = (road_length, road_width, road_height)
		create_intersection(*dimensions)

	def get_state(self):
		"""Return the agent controller's state for the goal at the north end of the road."""
		# Goal sits 0.5 short of the +y end of the road, on the centre line.
		north_goal = [0, self.road_length/2 - 0.5, 0]
		return self.agent.get_state(north_goal)
	
	def is_done(self, state):
		"""Return True when any termination condition holds, recording why in ``end_reason``.

		Conditions are evaluated in a fixed order and every one that holds
		overwrites ``self.end_reason``, so the last matching reason wins.
		``end_reason`` is left untouched when nothing matches.

		Args:
			state: Indexable observation; index 0 and 1 are read as x and y.
		"""
		x_pos, y_pos = state[0], state[1]
		checks = (
			# Agent moved past the success line
			(y_pos > self.success_threshold, "success"),
			# Agent left the road bounds
			(abs(x_pos) > self.road_width/2 or abs(y_pos) > self.road_length/2, "out of bounds"),
			# Any contact point between the two robots
			(len(p.getContactPoints(self.turtle, self.turtle_east)) > 0, "collision"),
			# Episode length limit
			(self.current_step >= self.max_episode_steps, "episode limit reached"),
		)

		done = False
		for triggered, reason in checks:
			if triggered:
				done = True
				self.end_reason = reason
		return done

	def reset(self):
		"""Reset the environment to its initial state.

		Both robots are put back at their spawn poses with zero velocity; the
		east robot's x position is shifted by a random amount in [-1, 1).
		All per-episode flags, the step counter and the previous episode's
		``end_reason`` are cleared.

		Returns:
			The state array produced by ``get_state()`` (not a tuple).
		"""
		# Counters and progress memory
		self.current_step = 0
		self.previous_distance = None

		# Clear the previous episode's outcome so it cannot leak into the
		# 'reason' entry of the new episode's info dict.
		self.end_reason = ""

		# Per-episode flags
		self.has_stopped_at_line = False
		self.has_penalized_no_stop = False
		self.has_reached_target = False
		self.has_collided = False
		self.safe_to_cross = False
		self.cut_off = False
		self.passed_stop = False

		# Learning robot (south)
		p.resetBasePositionAndOrientation(
			self.turtle,
			self.offset,
			self.initial_orientation
		)
		p.resetBaseVelocity(self.turtle, [0, 0, 0], [0, 0, 0])

		# Scripted robot (east), with randomised x
		x_jitter = (np.random.random()-.5)*2
		p.resetBasePositionAndOrientation(
			self.turtle_east,
			[self.east_position[0] + x_jitter, self.east_position[1], self.east_position[2]],
			self.east_orientation
		)
		p.resetBaseVelocity(self.turtle_east, [0, 0, 0], [0, 0, 0])

		self.east_pos = self.agent_east.get_pose()[0]

		# East robot stop-line state. step() zeroes the counter again when the
		# robot first enters its stop zone.
		self.east_robot_stopped = False
		self.east_stop_counter = np.random.randint(0, self.stop_duration)

		# Same goal and state query as every later step
		return self.get_state()
	
	def apply_action(self, action):
		"""
		Send a normalized continuous action to the learning robot.

		Each component is clipped to its bounds in ``self.action_space`` and
		then scaled by the matching maximum speed before being handed to the
		agent controller's ``move``.

		Args:
			action: numpy array or list [linear_velocity, angular_velocity],
					each nominally in [-1, 1]
		"""
		lin_bounds = self.action_space['linear']
		ang_bounds = self.action_space['angular']

		# Clip, then convert from normalized units to real velocities
		linear_vel = np.clip(action[0], lin_bounds[0], lin_bounds[1]) * self.max_linear_speed
		angular_vel = np.clip(action[1], ang_bounds[0], ang_bounds[1]) * self.max_angular_speed

		self.agent.move(linear_vel, angular_vel)

	def step(self, action):
		"""
		Advance the environment by one simulation tick.

		The learning robot is commanded first, physics is stepped once, and
		then the scripted east robot receives its next command (which therefore
		takes effect on the following tick).

		Args:
			action: [linear_velocity, angular_velocity] each in range [-1, 1]
		Returns:
			state: Current state observation
			reward: Reward for current step
			done: Whether episode has ended
			info: Additional information
		"""
		self.apply_action(action)
		p.stepSimulation()

		# East robot position after the tick; also cached for compute_reward()
		self.east_pos = self.agent_east.get_pose()[0]
		east_xyz = self.agent_east.get_pose()[0]

		# NOTE(review): the goal line is tested against index 1 (y) while the
		# stop zone and the drift correction use index 0 (x) - confirm intent.
		goal_line = -self.road_length/2 + 0.5
		stop_line = 1.4

		# Proportional correction towards the spawn x position, tightly clamped
		drift = self.east_position[0] - self.east_pos[0]
		steer = np.clip(drift * 0.5, -0.03, 0.03)

		# Decide whether the east robot must hold at its stop zone this tick
		hold = False
		if abs(east_xyz[0] - stop_line) < 0.1:  # within 0.1 of the stop line
			if not self.east_robot_stopped:
				# First tick inside the zone: start the wait from zero
				self.east_robot_stopped = True
				self.east_stop_counter = 0
			hold = self.east_stop_counter < self.stop_duration

		if hold:
			self.agent_east.move(0, 0)
			self.east_stop_counter += 1
		elif east_xyz[1] > goal_line:
			self.agent_east.move(self.east_robot_speed, steer)
		else:
			self.agent_east.move(0, 0)

		self.current_step += 1

		# Observe, score, and check for termination - in that order, because
		# compute_reward() updates flags that the info dict reports.
		state = self.get_state()
		reward = self.compute_reward(state)
		done = self.is_done(state)

		info = {
			'distance_to_target': state[6],
			'step': self.current_step,
			'reason': self.end_reason,
			'stopped': self.has_stopped_at_line,
			"safe_cross": self.has_stopped_at_line and not self.cut_off and self.safe_to_cross,
			'stopped_east': self.east_robot_stopped,
		}
		return state, reward, done, info
	

	def compute_reward(self, state):
		"""Return the reward for the current step.

		The reward is the sum of:
		  * progress: 10x the decrease in |y - success_threshold| since the
		    previous call, never negative;
		  * success: one-time +100 when y first exceeds the success threshold;
		  * stop line: one-time +50 for being slow enough just before the line,
		    otherwise -1 per step while near or past it without having stopped;
		  * wall penalty: currently always 0;
		  * collision: -100 on contact with the east robot and -5 while the two
		    robots are less than 1 apart in the x-y plane.

		``safe_to_cross``, ``cut_off`` and ``passed_stop`` are not consulted here.

		Args:
			state: Indexable observation; index 0/1 are read as x/y and index 4
				as the linear velocity.
		"""
		pos_x, pos_y = state[0], state[1]
		other_x, other_y = self.east_pos[0], self.east_pos[1]
		forward_speed = state[4]

		# --- Progress towards the success line ---
		remaining = abs(pos_y - self.success_threshold)
		if self.previous_distance is None:
			self.previous_distance = remaining
		progress_reward = max((self.previous_distance - remaining) * 10.0, 0)

		# Create the memory flags on first use if nobody has set them yet
		for flag in ('has_stopped_at_line', 'has_reached_target', 'has_collided'):
			if not hasattr(self, flag):
				setattr(self, flag, False)

		# --- Stop line ---
		line_y = -self.road_length/2 + 3.0
		# Only the 0.4-deep band immediately before the line counts
		near_line = abs(pos_y - line_y) < 0.4 and pos_y < line_y
		halted = abs(forward_speed) < 0.065

		stop_line_reward = 0
		if not self.has_stopped_at_line:
			if near_line:
				if halted:
					stop_line_reward = 50.0  # one-time reward for stopping
					self.has_stopped_at_line = True
				else:
					stop_line_reward = -1.0  # still moving inside the band
			elif pos_y > line_y:
				stop_line_reward = -1.0  # past the line without having stopped

		# --- Walls (disabled) ---
		wall_penalty = 0

		# --- Collision and proximity to the east robot ---
		collision_penalty = 0
		if len(p.getContactPoints(self.turtle, self.turtle_east)) > 0:
			collision_penalty -= 100.0
			self.has_collided = True
		if np.sqrt((pos_x-other_x)**2 + (pos_y-other_y)**2) < 1:
			collision_penalty -= 5

		# --- Success (paid once) ---
		success_reward = 0
		if pos_y > self.success_threshold and not self.has_reached_target:
			success_reward = 100.0
			self.has_reached_target = True

		# Remember the distance for the next progress computation
		self.previous_distance = remaining

		return progress_reward + success_reward + stop_line_reward + wall_penalty + collision_penalty

	def process_keyboard_events(self):
		"""Forward keyboard-event processing to the agent controller; returns nothing."""
		controller = self.agent
		controller.process_keyboard_events()

	def keyboard_teleop(self):
		"""Forward keyboard teleoperation to the agent controller; returns nothing."""
		controller = self.agent
		controller.keyboard_teleop()

	def get_camera_image(self):
		"""
		Return whatever the agent controller's ``get_camera_image`` produces.
		"""
		frame = self.agent.get_camera_image()
		return frame

	def get_env_image(self, width_px=320, height_px=320):
		"""
		Render a bird's-eye RGB image of the environment with PyBullet's camera.

		The virtual camera targets the world origin from a distance of 9 with
		a pitch of -90 degrees, i.e. it looks straight down onto the scene.

		Args:
			width_px: Image width in pixels (default 320).
			height_px: Image height in pixels (default 320).

		Returns:
			Array of shape (height_px, width_px, 3) holding the RGB channels;
			the alpha channel returned by PyBullet is dropped.
		"""
		# Overhead view matrix
		view_matrix = p.computeViewMatrixFromYawPitchRoll(
			cameraTargetPosition=[0, 0, 0],
			distance=9,
			yaw=0,
			pitch=-90,
			roll=0,
			upAxisIndex=2
		)

		# Perspective projection matching the requested aspect ratio
		proj_matrix = p.computeProjectionMatrixFOV(
			fov=87,
			aspect=float(width_px) / float(height_px),
			nearVal=0.1,
			farVal=100.0
		)

		# Only the RGBA pixels are needed; depth and segmentation are ignored
		_, _, rgba_pixels, _, _ = p.getCameraImage(
			width_px, height_px,
			viewMatrix=view_matrix,
			projectionMatrix=proj_matrix
		)

		# Reshape to (H, W, 4) and drop the alpha channel
		rgb_array = np.reshape(rgba_pixels, (height_px, width_px, 4))[:, :, :3]
		return rgb_array
	
	"""def create_intersection(self):
		# Road dimensions
		road_length = 10.0
		road_width = 2.0
		road_height = 0.1
		
		# Move intersection up by adjusting z position
		z_offset = 0.05
		
		# Visual properties
		road_color = [0.1, 0.1, 0.1, 1]  # Dark gray
		line_color = [1, 1, 1, 1]        # White
		wall_color = [0, 0, 0, 0]        # Black
		
		# Create the roads (North-South and East-West)
		road_shapes = [
			# North-South road
			[0, 0, z_offset, 0, 0, 0, road_width, road_length, road_height],
			# East-West road
			[0, 0, z_offset, 0, 0, 0, road_length, road_width, road_height]
		]
		
		# Create collision and visual shapes for roads
		for shape in road_shapes:
			visual_shape = p.createVisualShape(
				shapeType=p.GEOM_BOX,
				halfExtents=[shape[6]/2, shape[7]/2, shape[8]/2],
				rgbaColor=road_color
			)
			collision_shape = p.createCollisionShape(
				shapeType=p.GEOM_BOX,
				halfExtents=[shape[6]/2, shape[7]/2, shape[8]/2]
			)
			p.createMultiBody(
				baseMass=0,
				baseCollisionShapeIndex=collision_shape,
				baseVisualShapeIndex=visual_shape,
				basePosition=[shape[0], shape[1], shape[2]],
				baseOrientation=p.getQuaternionFromEuler([shape[3], shape[4], shape[5]])
			)

		# Add stop lines on all lanes
		stop_line_half_extents = [0.3, 0.05, 0.001] 
		stop_line_color = [1, 0, 0, 1]  # Red color for stop lines

		# Define stop line positions for each lane - adjust to match reward function
		stop_line_positions = [
			# North lane (horizontal line)
			[-0.5, self.road_length/2 - 3.8, z_offset + 0.1, 0, 0, 0],
			# South lane (horizontal line)
			[0.5, -self.road_length/2 + 3.8, z_offset + 0.1, 0, 0, 0],
			# East lane (vertical line)
			[self.road_length/2 - 3.8, 0.5, z_offset + 0.1, 0, 0, np.pi/2],
			# West lane (vertical line)
			[-self.road_length/2 + 3.8, -0.5, z_offset + 0.1, 0, 0, np.pi/2]
		]

		# Create stop lines
		for pos in stop_line_positions:
			stop_line_vis = p.createVisualShape(
				shapeType=p.GEOM_BOX,
				halfExtents=stop_line_half_extents,
				rgbaColor=stop_line_color
			)
			p.createMultiBody(
				baseMass=0,
				baseVisualShapeIndex=stop_line_vis,
				basePosition=pos[:3],
				baseOrientation=p.getQuaternionFromEuler(pos[3:])
			)

		# Add walls along the roads
		wall_height = 0.5  # Height of the walls
		wall_thickness = 0.1  # Thickness of the walls
		wall_length = 4.0  # Length of each wall segment
		wall_offset = 0.5  # Offset from the intersection
		
		# Define wall positions (along the roads)
		wall_configs = [
			# North road walls (2 walls)
			# Upper East wall
			[road_width/2, road_length/4 + wall_offset, wall_height/2, 
			 wall_thickness, wall_length, wall_height],
			# Upper West wall
			[-road_width/2, road_length/4 + wall_offset, wall_height/2, 
			 wall_thickness, wall_length, wall_height],
			 
			# South road walls (2 walls)
			# Lower East wall
			[road_width/2, -road_length/4 - wall_offset, wall_height/2, 
			 wall_thickness, wall_length, wall_height],
			# Lower West wall
			[-road_width/2, -road_length/4 - wall_offset, wall_height/2, 
			 wall_thickness, wall_length, wall_height],
			 
			# East road walls (2 walls)
			# Right Upper wall
			[road_length/4 + wall_offset, road_width/2, wall_height/2, 
			 wall_length, wall_thickness, wall_height],
			# Right Lower wall
			[road_length/4 + wall_offset, -road_width/2, wall_height/2, 
			 wall_length, wall_thickness, wall_height],
			 
			# West road walls (2 walls)
			# Left Upper wall
			[-road_length/4 - wall_offset, road_width/2, wall_height/2, 
			 wall_length, wall_thickness, wall_height],
			# Left Lower wall
			[-road_length/4 - wall_offset, -road_width/2, wall_height/2, 
			 wall_length, wall_thickness, wall_height],
		]
		
		# Create walls
		for wall in wall_configs:
			visual_shape = p.createVisualShape(
				shapeType=p.GEOM_BOX,
				halfExtents=[wall[3]/2, wall[4]/2, wall[5]/2],
				rgbaColor=wall_color
			)
			collision_shape = p.createCollisionShape(
				shapeType=p.GEOM_BOX,
				halfExtents=[wall[3]/2, wall[4]/2, wall[5]/2]
			)
			p.createMultiBody(
				baseMass=0,  # Static object
				baseCollisionShapeIndex=collision_shape,
				baseVisualShapeIndex=visual_shape,
				basePosition=[wall[0], wall[1], wall[2]]
			)

		# Add road lines
		line_width = 0.1
		line_length = road_length/4
		line_height = road_height + 0.001
		
		line_shapes = [
			# North center line
			[0, road_length/4, z_offset + 0.001, 0, 0, 0, line_width, line_length, line_height],
			# South center line
			[0, -road_length/4, z_offset + 0.001, 0, 0, 0, line_width, line_length, line_height],
			# East center line
			[road_length/4, 0, z_offset + 0.001, 0, 0, 0, line_length, line_width, line_height],
			# West center line
			[-road_length/4, 0, z_offset + 0.001, 0, 0, 0, line_length, line_width, line_height]
		]
		
		# Create lines
		for line in line_shapes:
			visual_shape = p.createVisualShape(
				shapeType=p.GEOM_BOX,
				halfExtents=[line[6]/2, line[7]/2, line[8]/2],
				rgbaColor=line_color
			)
			p.createMultiBody(
				baseMass=0,
				baseVisualShapeIndex=visual_shape,
				basePosition=[line[0], line[1], line[2]],
				baseOrientation=p.getQuaternionFromEuler([line[3], line[4], line[5]])
			)

		# Add stop signs at each lane
		stop_sign_height = 1.0  
		stop_sign_size = 0.3   
		pole_radius = 0.02    
		pole_height = 0.7    
		
		# Define stop sign positions for each lane
		stop_sign_positions = [
			# South lane (moved left)
			[-self.road_width/2 + 2, -self.road_length/2 + 4.0, stop_sign_height/4 + pole_height, 0, 0, 0],
			# North lane (moved left)
			[-self.road_width/2 - 0.3, self.road_length/2 - 4.0, stop_sign_height/4 + pole_height, 0, 0, np.pi],
			# East lane (moved left relative to direction of travel)
			[self.road_length/2 - 4.0, self.road_width/2 + 0.3, stop_sign_height/4 + pole_height, 0, 0, -np.pi/2],
			# West lane (moved left relative to direction of travel)
			[-self.road_length/2 + 4.0, -self.road_width/2 - 0.1 , stop_sign_height/4 + pole_height, 0, 0, np.pi/2]
		]
		
		# Create stop signs and poles
		for pos in stop_sign_positions:
			# Create the stop sign (red octagon)
			stop_sign_visual = p.createVisualShape(
				shapeType=p.GEOM_BOX,  
				halfExtents=[stop_sign_size/2, 0.02, stop_sign_size/2],
				rgbaColor=[0.8, 0, 0, 1]  
			)
			stop_sign_collision = p.createCollisionShape(
				shapeType=p.GEOM_BOX,
				halfExtents=[stop_sign_size/2, 0.02, stop_sign_size/2]
			)
			
			# Create the pole (gray cylinder)
			pole_visual = p.createVisualShape(
				shapeType=p.GEOM_CYLINDER,
				radius=pole_radius,
				length=pole_height,
				rgbaColor=[0.7, 0.7, 0.7, 1]  # Gray color
			)
			pole_collision = p.createCollisionShape(
				shapeType=p.GEOM_CYLINDER,
				radius=pole_radius,
				height=pole_height
			)
			
			# Create the stop sign and pole as separate bodies
			p.createMultiBody(
				baseMass=0,  # Static object
				baseCollisionShapeIndex=stop_sign_collision,
				baseVisualShapeIndex=stop_sign_visual,
				basePosition=pos[:3],
				baseOrientation=p.getQuaternionFromEuler(pos[3:])
			)
			
			# Position the pole below the sign
			pole_position = [pos[0], pos[1], pole_height/2]
			p.createMultiBody(
				baseMass=0,  # Static object
				baseCollisionShapeIndex=pole_collision,
				baseVisualShapeIndex=pole_visual,
				basePosition=pole_position,
				baseOrientation=p.getQuaternionFromEuler([0, 0, 0])
			)"""
			
	def close(self):
		"""Disconnect from the physics server if a connection is still open.

		Safe to call more than once: a second call finds no connection and
		does nothing instead of raising.
		"""
		if p.isConnected():
			p.disconnect()