import pybullet as p
import time
import os
import cv2
import numpy as np
import pybullet_data
import importlib.util
from env.controller import TurtlebotController
from env.intersection import create_real_intersection, create_real_intersection_walls
	
class TurtlebotSim:
    """PyBullet intersection scenario with an ego TurtleBot and a scripted one.

    The ego robot (``self.turtle`` / ``self.agent``) spawns near the south end
    of the north-south road and is driven by the actions passed to
    :meth:`step`; the episode counts as a success once its y coordinate
    exceeds ``success_threshold``. A second robot (``self.turtle_east`` /
    ``self.agent_east``) spawns near the east end and follows a scripted
    controller: it halts at its stop line for ``stop_duration`` steps and, on
    roughly half of the episodes, turns in place at the origin before
    continuing.
    """

    def __init__(self, gui, use_egl, ego_speed, tb2_speed, obs_size, action_size):
        """Connect to PyBullet, build the intersection and spawn both robots.

        Args:
            gui: If true, connect with ``p.GUI``; otherwise with ``p.DIRECT``.
            use_egl: If true, try to load PyBullet's ``eglRenderer`` plugin
                (intended for headless rendering).
            ego_speed: Linear speed (m/s) that a normalized linear action of
                1.0 maps to.
            tb2_speed: Forward speed passed to the east robot's ``move`` calls.
            obs_size: Forwarded to each ``TurtlebotController``.
            action_size: Forwarded to each ``TurtlebotController``.

        Raises:
            SystemExit: With status 1 if either URDF fails to load.
        """
        try:
            if gui:
                p.connect(p.GUI)
            else:
                p.connect(p.DIRECT)

            # Physics-engine tuning used for every simulation of this scenario.
            p.setPhysicsEngineParameter(enableFileCaching=0)
            p.setPhysicsEngineParameter(numSolverIterations=50)
            p.setPhysicsEngineParameter(enableConeFriction=0)
            p.setPhysicsEngineParameter(contactBreakingThreshold=0.001)
            p.setPhysicsEngineParameter(allowedCcdPenetration=0.0)
            p.setPhysicsEngineParameter(numSubSteps=2)
            p.setPhysicsEngineParameter(enableSAT=0)

            if use_egl:  # only meaningful for headless (DIRECT) mode
                egl = importlib.util.find_spec('eglRenderer')
                if egl:
                    p.loadPlugin(egl.origin, "_eglRendererPlugin")
        except Exception as e:
            # Best effort: if anything above failed, make sure we still end up
            # with a (headless) connection rather than crashing.
            print(f"Warning: Could not initialize GPU rendering: {e}")
            print("Falling back to CPU rendering")
            if not p.isConnected():
                p.connect(p.DIRECT)

        p.setGravity(0, 0, -9.81)
        p.setAdditionalSearchPath(os.path.join(os.path.dirname(__file__), "..", "data"))

        self.east_robot_speed = tb2_speed
        self.obs_size = obs_size
        self.action_size = action_size

        # Road geometry (SI units; 4.2672 = 14 ft, 0.6096 = 2 ft, 0.0254 = 1 in).
        self.road_length = 4.2672
        self.road_width = 0.6096
        self.road_height = 0.0254
        self.load_robot_height = 0.0254  # spawn height (z) of both robots
        self.road_length_offset = 0.3    # how far in from a road end the robots spawn
        self.road_width_offset = 1

        self.max_episode_steps = 1000
        self.success_threshold = 0.75  # ego y coordinate that counts as reaching the goal

        # Ego robot: 0.3 m in from the south end, yaw = pi/2, i.e. facing +y
        # (north, toward the goal) assuming the URDF's forward axis is +x.
        self.offset = [
            self.road_width / 2 - self.road_length_offset,    # x ~= 0.005
            -self.road_length / 2 + self.road_length_offset,  # y: 0.3 from the south end
            self.load_robot_height,                           # z: spawn height
        ]
        self.initial_orientation = p.getQuaternionFromEuler([0, 0, np.pi / 2])

        # East robot: 0.3 m in from the east end. PyBullet quaternions are
        # [x, y, z, w], so [0, 0, 1, 0] is a yaw of pi, i.e. facing -x toward
        # the intersection (again assuming +x is the URDF's forward axis).
        self.east_position = [
            self.road_length / 2 - self.road_length_offset,
            self.road_width / 2 - self.road_length_offset,
            self.load_robot_height,
        ]
        self.east_orientation = [0, 0, 1, 0]

        self.urdf_path = os.path.join(os.path.dirname(__file__), "..", "data", "turtlebot_new.urdf")
        self.urdf_path_east = os.path.join(os.path.dirname(__file__), "..", "data", "turtlebot_east_west.urdf")

        self.create_intersection(self.road_length, self.road_width, self.road_height)

        # Normalized action bounds; apply_action scales them by the max speeds.
        self.action_space = {
            'linear': [-1.0, 1.0],
            'angular': [-1.0, 1.0],
        }
        self.max_linear_speed = ego_speed  # m/s at |linear action| == 1
        self.max_angular_speed = 1.0       # rad/s at |angular action| == 1
        self.max_force = 100.0

        try:
            robot_id_1 = p.loadURDF(self.urdf_path, self.offset, self.initial_orientation)
            self.turtle = robot_id_1
            self.agent = TurtlebotController(robot_id_1, self.obs_size, self.action_size)

            robot_id_2 = p.loadURDF(self.urdf_path_east, self.east_position, self.east_orientation)
            self.turtle_east = robot_id_2
            self.agent_east = TurtlebotController(robot_id_2, self.obs_size, self.action_size)

            self.robots = {
                'turtle_south': (robot_id_1, self.agent),
                'turtle_east': (robot_id_2, self.agent_east),
            }
        except p.error as e:
            print(f"Error loading URDF files: {e}")
            p.disconnect()
            # Same effect as the builtin exit(1), without depending on the
            # `site` module having installed that interactive helper.
            raise SystemExit(1) from e

        p.setRealTimeSimulation(0)
        p.setTimeStep(1/30)

        # Scripted east-robot parameters and state.
        self.stop_duration = 5  # steps the east robot waits at its stop line
        self.east_stop_counter = 0
        self.east_pos = [0, 0, 0]

        # Per-episode bookkeeping (reset() clears the same fields).
        self._reset_episode_flags()

    def _reset_episode_flags(self):
        """Clear per-episode bookkeeping without touching the simulator or the RNG."""
        self.current_step = 0
        self.previous_distance = None  # re-seeded on the next compute_reward call
        self.end_reason = ""
        # Stays True until the ego robot is past its stop line while the east
        # robot has not yet completed its stop (see compute_reward).
        self.has_stopped_at_line = True
        self.has_penalized_no_stop = False
        self.has_reached_target = False
        self.has_collided = False
        self.safe_to_cross = False
        self.cut_off = False
        self.passed_stop = False
        self.east_robot_stopped = False
        self.east_at_origin = False
        self.east_rotate_done = False

    @staticmethod
    def _yaw_error(yaw, target_yaw):
        """Return ``yaw - target_yaw`` wrapped into [-pi, pi)."""
        return (yaw - target_yaw + np.pi) % (2 * np.pi) - np.pi

    def create_intersection(self, road_length=10.0, road_width=2.0, road_height=0.1):
        """Build the intersection walls via ``create_real_intersection_walls``."""
        create_real_intersection_walls(road_length, road_width, road_height)

    def _goal_position(self):
        """Ego goal point: x = 0, 0.65 short of the north end of the road."""
        return [0, self.road_length / 2 - 0.65, 0]

    def get_state(self):
        """Return the ego observation produced by its controller for the north goal."""
        return self.agent.get_state(self._goal_position())

    def is_done(self, state):
        """Return True if the episode should end, recording why in ``end_reason``.

        Conditions are checked in order and a later match overwrites
        ``end_reason``, so the reported reason's priority is:
        episode limit > collision > out of bounds > success.

        Args:
            state: Ego observation; indices 0 and 1 are treated as x and y.
        """
        done = False

        if state[1] > self.success_threshold:
            done = True
            self.end_reason = "success"

        if abs(state[0]) > self.road_width / 2 or abs(state[1]) > self.road_length / 2:
            done = True
            self.end_reason = "out of bounds"

        # Only contacts with the east robot count as collisions.
        if len(p.getContactPoints(self.turtle, self.turtle_east)) > 0:
            done = True
            self.end_reason = "collision"

        if self.current_step >= self.max_episode_steps:
            done = True
            self.end_reason = "episode limit reached"

        return done

    def reset(self):
        """Reset both robots and all episode bookkeeping; return the ego observation.

        Randomization (drawn from ``np.random`` in this order): whether the
        east robot skips its turn at the origin, the ego's y offset in
        [0, 0.5), the east robot's x offset in [0, 1), and the east robot's
        initial stop counter in [0, stop_duration).
        """
        self._reset_episode_flags()

        # About half the episodes: the east robot goes straight (turn "already done").
        self.east_rotate_done = np.random.random() > 0.5

        p.resetBasePositionAndOrientation(
            self.turtle,
            [self.offset[0], self.offset[1] + 0.5 * np.random.random(), self.offset[2]],
            self.initial_orientation,
        )
        p.resetBaseVelocity(self.turtle, [0, 0, 0], [0, 0, 0])

        p.resetBasePositionAndOrientation(
            self.turtle_east,
            [self.east_position[0] - np.random.random(), self.east_position[1], self.east_position[2]],
            self.east_orientation,
        )
        p.resetBaseVelocity(self.turtle_east, [0, 0, 0], [0, 0, 0])

        self.east_pos = self.agent_east.get_pose()[0]
        self.east_stop_counter = np.random.randint(0, self.stop_duration)

        return self.get_state()

    def apply_action(self, action):
        """
        Apply continuous action for robot control.
        Args:
            action: numpy array or list [linear_velocity, angular_velocity]
                    Each value should be between -1 and 1 (values outside are clipped)
        """
        linear_action = np.clip(action[0], self.action_space['linear'][0], self.action_space['linear'][1])
        angular_action = np.clip(action[1], self.action_space['angular'][0], self.action_space['angular'][1])

        # Scale normalized actions to actual velocities.
        linear_vel = linear_action * self.max_linear_speed
        angular_vel = angular_action * self.max_angular_speed

        self.agent.move(linear_vel, angular_vel)

    def _drive_east_robot(self):
        """Scripted east-robot controller: stop line, optional turn, then drive on.

        The east robot travels along x, so its stop line (x = 0.5) and the
        origin check are both tested on the x coordinate.
        """
        position = self.agent_east.get_pose()[0]
        self.east_pos = position
        x = position[0]
        stop_line_x = 0.5

        if abs(x - stop_line_x) < 0.1:  # within 0.1 of the stop line
            if not self.east_robot_stopped:
                self.east_robot_stopped = True
                self.east_stop_counter = 0
            if self.east_stop_counter < self.stop_duration:
                self.agent_east.stop()
                self.east_stop_counter += 1
            else:
                self.agent_east.move(self.east_robot_speed, 0)
        elif not self.east_at_origin and abs(x) <= 0.05:
            # Reached the origin: stop for one step before (possibly) turning.
            self.east_at_origin = True
            self.agent_east.stop()
        elif self.east_at_origin and not self.east_rotate_done:
            # Turn clockwise in place until the yaw reaches pi/2 (a 90 degree
            # turn from the initial yaw of pi).
            self.agent_east.move(0, -1.0)
            _, orientation = self.agent_east.get_pose()
            yaw = p.getEulerFromQuaternion(orientation)[2]
            if abs(self._yaw_error(yaw, np.pi / 2)) < 0.1:
                self.east_rotate_done = True
        else:
            # Approaching the intersection, or driving on after the turn.
            self.agent_east.move(self.east_robot_speed, 0)

    def _finish_step(self):
        """Observe, score and check termination; return the gym-style step tuple."""
        state = self.get_state()
        reward = self.compute_reward(state)
        done = self.is_done(state)
        info = {
            'distance_to_target': state[6],
            'step': self.current_step,
            'reason': self.end_reason,
            'stopped': self.has_stopped_at_line,
            "safe_cross": self.has_stopped_at_line and not self.cut_off and self.safe_to_cross,
            'stopped_east': self.east_robot_stopped,
        }
        return state, reward, done, info

    def step(self, action):
        """
        Execute one time step within the environment
        Args:
            action: [linear_velocity, angular_velocity] each in range [-1, 1]
        Returns:
            state: Current state observation
            reward: Reward for current step
            done: Whether episode has ended
            info: Additional information
        """
        self.apply_action(action)
        p.stepSimulation()
        self._drive_east_robot()
        self.current_step += 1
        return self._finish_step()

    def step_old(self, action):
        """Legacy variant of :meth:`step` with a simpler east-robot controller.

        The east robot pauses at the stop line (x = 0.5), then keeps driving
        while its y coordinate is above the south-end target, with a small
        proportional steering term (clipped to +/-0.02) that pulls its x back
        toward the spawn x. Returns the same tuple as :meth:`step`.
        """
        self.apply_action(action)
        p.stepSimulation()

        position = self.agent_east.get_pose()[0]
        self.east_pos = position
        south_end_y = -self.road_length / 2 + self.road_length_offset
        stop_line_x = 0.5
        # Drift correction toward the original spawn x (cause of drift unknown).
        heading_correction = np.clip((self.east_position[0] - position[0]) * 0.5, -0.02, 0.02)

        holding_at_line = False
        if abs(position[0] - stop_line_x) < 0.1:  # within 0.1 of the stop line
            if not self.east_robot_stopped:
                self.east_robot_stopped = True
                self.east_stop_counter = 0
            if self.east_stop_counter < self.stop_duration:
                holding_at_line = True
                self.east_stop_counter += 1

        if holding_at_line:
            self.agent_east.move(0, 0)
        elif position[1] > south_end_y:
            self.agent_east.move(self.east_robot_speed, heading_correction)
        else:
            self.agent_east.move(0, 0)

        self.current_step += 1
        return self._finish_step()

    def compute_reward(self, state):
        """Return the shaped reward for the ego robot's current ``state``.

        Terms:
          * progress: 20 x the decrease in |y - success_threshold| since the
            previous call, floored at 0;
          * stop line: -5 on every step the ego is past y = -0.5 while the
            east robot has not yet completed its stop (this also clears
            ``has_stopped_at_line``);
          * collision: -100 on contact with the east robot, plus -5 whenever
            the two robots' x/y positions are closer than 0.6;
          * success: one-time +200 when y first exceeds ``success_threshold``.
        A wall-proximity penalty is disabled and not part of the sum.
        """
        current_x, current_y = state[0], state[1]
        east_x, east_y = self.east_pos[0], self.east_pos[1]

        current_distance = abs(current_y - self.success_threshold)
        if self.previous_distance is None:
            self.previous_distance = current_distance
        progress_reward = max((self.previous_distance - current_distance) * 20.0, 0)

        stop_line_y = -0.5
        stop_line_reward = 0
        if current_y > stop_line_y and self.east_stop_counter < self.stop_duration:
            stop_line_reward -= 5
            self.has_stopped_at_line = False

        collision_penalty = 0
        if len(p.getContactPoints(self.turtle, self.turtle_east)) > 0:
            collision_penalty -= 100.0
            self.has_collided = True
        if np.sqrt((current_x - east_x) ** 2 + (current_y - east_y) ** 2) < 0.6:
            collision_penalty -= 5

        success_reward = 0
        if current_y > self.success_threshold and not self.has_reached_target:
            success_reward = 200.0
            self.has_reached_target = True

        self.previous_distance = current_distance

        return progress_reward + success_reward + stop_line_reward + collision_penalty

    def process_keyboard_events(self):
        """Delegate keyboard event processing to the agent controller."""
        self.agent.process_keyboard_events()

    def keyboard_teleop(self):
        """Delegate keyboard teleop to the agent controller."""
        self.agent.keyboard_teleop()

    def get_camera_image(self):
        """Return the ego robot's camera image from its controller."""
        return self.agent.get_camera_image()

    def get_env_image(self, width_px=320, height_px=320):
        """Render a top-down RGB view of the scene.

        The camera looks straight down (pitch -90 degrees) at the origin from
        a distance of 9, with a 95 degree field of view.

        Args:
            width_px: Image width in pixels (default 320).
            height_px: Image height in pixels (default 320).

        Returns:
            A ``(height_px, width_px, 3)`` uint8 RGB array.
        """
        view_matrix = p.computeViewMatrixFromYawPitchRoll(
            cameraTargetPosition=[0, 0, 0],
            distance=9,
            yaw=0,
            pitch=-90,
            roll=0,
            upAxisIndex=2,
        )
        proj_matrix = p.computeProjectionMatrixFOV(
            fov=95,
            aspect=float(width_px) / float(height_px),
            nearVal=0.1,
            farVal=10.0,
        )
        _, _, rgb_img, _, _ = p.getCameraImage(
            width_px, height_px,
            viewMatrix=view_matrix,
            projectionMatrix=proj_matrix,
        )
        # PyBullet returns RGBA values in 0..255 either as a flat list or as a
        # NumPy array depending on how it was built; normalize to uint8 and
        # drop the alpha channel.
        rgba = np.reshape(np.asarray(rgb_img, dtype=np.uint8), (height_px, width_px, 4))
        return rgba[:, :, :3]

    def close(self):
        """Disconnect from the PyBullet physics server."""
        p.disconnect()