import math 
import time
import sys
import traceback
import numpy as np
from simworld.agent.base_agent import BaseAgent
from simworld.utils.vector import Vector
from simworld.utils.logger import Logger
from simworld.traffic.base.traffic_signal import TrafficSignalState
from agent.action_space import Action, ActionSpace
from llm.prompt import VLM_SYSTEM_PROMPT, VLM_USER_PROMPT

class NavAgent(BaseAgent):
    _id_counter = 0
    _camera_id_counter = 1
    def __init__(self, position, direction, communicator, coarse_grained_map, fine_grained_map, llm, destination, rule_based, config, task_name):
        """Create a navigation agent, plan its route and reset all statistics.

        Args:
            position: Start position, forwarded to ``BaseAgent``. It is replaced
                during route planning (see ``get_path_positions``).
            direction: Initial heading, forwarded to ``BaseAgent``.
            communicator: Object used to issue movement commands and to read
                camera observations and collision numbers.
            coarse_grained_map: Stored on the instance; not used here.
            fine_grained_map: Map used for routing and node lookups.
            llm: Client used in vision-based mode to obtain actions.
            destination: Goal position; replaced during route planning.
            rule_based: Truthy selects rule-based navigation in ``run``.
            config: Mapping of settings; ``'pysbench.history_window'`` is read here.
            task_name: Stored on the instance.
        """
        super().__init__(position, direction)

        # Every agent takes the next agent id and the next camera id.
        self.id = NavAgent._id_counter
        self.camera_id = NavAgent._camera_id_counter
        NavAgent._id_counter += 1
        NavAgent._camera_id_counter += 1

        # Collaborators and settings.
        self.config = config
        self.communicator = communicator
        self.llm = llm
        self.coarse_grained_map = coarse_grained_map
        self.fine_grained_map = fine_grained_map
        self.destination = destination
        self.rule_based = rule_based
        self.max_steps = 30  # default cap on vision-based decision steps

        # Route planning needs the map, position and destination set above.
        self.path = self.get_path_positions()
        self.minimum_steps = self.calculate_minimum_steps()

        # Outcome of the run.
        self.task_name = task_name
        self.success = False
        self.red_light_violation = False
        self.get_stuck = False
        self.step_limit_reached = False  # set once max_steps is hit
        self.elapsed_time = 0

        # Collision statistics.
        self.collision_count = self.human_collision_count = self.object_collision_count = 0

        # Per-action statistics.
        self.num_steps = self.num_step_forward = self.num_turn_around = 0
        self.num_choose_waypoint = self.num_valid_choose_waypoint = 0
        self.num_error_steps = 0

        # LLM latency: running sum and running mean.
        self.llm_call_time_sum = self.llm_call_time = 0

        # Short-term memory used when building prompts.
        self.action_history = []
        self.last_image = None
        self.last_position = None  # used for collision checks in vision-based mode
        self.history_window = self.config['pysbench.history_window']

        self.logger = Logger().get_logger('NavAgent')

    def run(self, exit_event):
        """Follow ``self.path`` waypoint by waypoint and record the outcome.

        The navigation mode is chosen by ``self.rule_based``. The loop stops as
        soon as ``exit_event`` is set or the step limit has been reached, so no
        further movement commands are issued for the remaining waypoints.

        Args:
            exit_event: Event-like object exposing ``is_set()``, or ``None``
                when the run cannot be cancelled.

        Side effects:
            Updates ``self.success``, ``self.elapsed_time`` and, in rule-based
            mode, the collision counters. ``self.elapsed_time`` is updated on
            every exit path, including failures. Exceptions are logged
            together with their traceback and are not re-raised.
        """
        self.logger.info(f"Agent {self.id} is running, max steps: {self.max_steps}")
        start_time = time.time()
        try:
            navigate = self.navigate_rule_based if self.rule_based else self.navigate_vision_based
            for point in self.path[1:]:
                # Once cancellation was requested there is nothing left to do.
                if exit_event is not None and exit_event.is_set():
                    break
                navigate(point, exit_event)
                # Only the vision-based navigator sets this flag.
                if self.step_limit_reached:
                    break

            if self.rule_based:
                # Rule-based mode reads the collision numbers once, at the end;
                # vision-based mode accounts for them after each forward step.
                human_collision, object_collision = self.communicator.get_collision_number(self.id)
                self.collision_count += human_collision + object_collision
                self.human_collision_count += human_collision
                self.object_collision_count += object_collision

            cancelled = exit_event is not None and exit_event.is_set()
            self.success = not (cancelled or self.step_limit_reached)
            self.logger.info(f"Agent {self.id} has collided {self.collision_count} times")
        except Exception as e:
            # Top-level boundary of the agent: report the failure, never raise.
            self.success = False
            self.logger.error(f"Error in agent {self.id}: {e}")
            self.logger.error(traceback.format_exc())
        finally:
            self.elapsed_time += time.time() - start_time

    def navigate_rule_based(self, point, exit_event):
        """Walk to ``point`` using geometric rules, waiting for a pedestrian green light first.

        If the map has traffic signals and the node closest to the agent is an
        intersection, the closest signal within twice the configured sidewalk
        offset is polled until it shows pedestrian green with enough time
        left. The agent then alternates between rotating towards ``point`` and
        moving forward until it arrives or the run is cancelled.

        Args:
            point: Target waypoint position.
            exit_event: Event-like object exposing ``is_set()``, or ``None``
                when the run cannot be cancelled.
        """
        self.logger.info(f"Agent {self.id} is navigating to {point}, current position: {self.position}, rule based mode")

        def cancelled():
            # ``exit_event`` is optional everywhere in this class.
            return exit_event is not None and exit_event.is_set()

        if self.fine_grained_map.traffic_signals:
            current_node = self.fine_grained_map.get_closest_node(self.position)
            if current_node.type == 'intersection':
                # Pick the closest signal that is nearer than the search radius.
                traffic_light = None
                min_distance = self.config['pysbench.sidewalk_offset'] * 2
                for signal in self.fine_grained_map.traffic_signals:
                    distance = self.position.distance(signal.position)
                    if distance < min_distance:
                        min_distance = distance
                        traffic_light = signal

                if traffic_light is not None:
                    while not cancelled():
                        state = traffic_light.get_state()
                        left_time = traffic_light.get_left_time()
                        # NOTE(review): if the configured green duration is <= 15
                        # this requires left_time to exceed the full duration;
                        # confirm that get_left_time() can satisfy that.
                        required_time = min(15, self.config['traffic.traffic_signal.pedestrian_green_light_duration'])
                        if state[1] == TrafficSignalState.PEDESTRIAN_GREEN and left_time > required_time:
                            break
                        time.sleep(self.config['simworld.dt'])

        self.communicator.agent_move_forward(self.id)
        while not self._arrive_at_waypoint(point) and not cancelled():
            # Stop and rotate until the heading is close enough to the waypoint.
            while not self._align_direction(point) and not cancelled():
                self.communicator.agent_stop(self.id)
                angle, turn = self._get_angle_and_direction(point)
                self.communicator.agent_rotate(self.id, angle, turn)
                time.sleep(self.config['simworld.dt'])
            self.communicator.agent_move_forward(self.id)
            time.sleep(self.config['simworld.dt'])
        self.communicator.agent_stop(self.id)

    def navigate_vision_based(self, point, exit_event):
        """Move towards ``point`` by repeatedly asking the vision-language model for an action.

        Each iteration captures a camera image, builds a prompt from the
        current state and the recent action history, asks the model for an
        action and executes it. The loop ends when the waypoint is reached,
        the run is cancelled, or ``self.max_steps`` decisions have been made
        (``self.step_limit_reached`` is then set).

        Args:
            point: Target waypoint position. A CHOOSE_WAYPOINT action replaces
                it for the remainder of this call.
            exit_event: Event-like object exposing ``is_set()``, or ``None``
                when the run cannot be cancelled.

        Raises:
            ValueError: If ``config['pysbench.agent_type']`` is neither
                ``'openai'`` nor ``'openrouter'``.
        """
        self.logger.info(f"Agent {self.id} is navigating to {point}, current position: {self.position}, vision based mode")

        possible_next_waypoints = self.get_possible_next_waypoints(point)
        while not self._arrive_at_waypoint(point) and (exit_event is None or not exit_event.is_set()):
            # Check if step limit is reached
            if self.num_steps >= self.max_steps:
                self.logger.info(f"Agent {self.id} reached maximum step limit of {self.max_steps}")
                self.step_limit_reached = True
                break

            planned_node = self.fine_grained_map.get_closest_node(point)
            time.sleep(1)
            image = self.communicator.get_camera_observation(self.camera_id, 'lit')

            # The model always receives two frames: the previous one (or the
            # current one twice on the first call) followed by the current one.
            previous_image = self.last_image if self.last_image is not None else image
            images = [previous_image, image]
            self.last_image = image

            relative_distance = self.position.distance(point)

            # Relative bearing of the waypoint with respect to the current yaw.
            # (Original note: UE yaw is clockwise from the X axis.)
            current_yaw_rad = math.radians(self.yaw)
            dx = point.x - self.position.x
            dy = point.y - self.position.y
            target_yaw_rad = math.atan2(dy, dx)
            # Normalise into [-180, 180] for any yaw value, not just one wrap.
            angle = math.remainder(math.degrees(target_yaw_rad - current_yaw_rad), 360.0)

            action_str = f"I was at {self.position} and I want to go to {point}. The relative distance is {relative_distance} cm and the relative angle is {angle} degrees. After I made the decision, "

            system_prompt = VLM_SYSTEM_PROMPT
            user_prompt = VLM_USER_PROMPT.format(
                current_position=self.position,
                current_direction=self.direction,
                target_position=self.destination,
                next_waypoint=point,
                possible_next_waypoints=possible_next_waypoints,
                relative_distance=relative_distance,
                relative_angle=angle,
                action_history=self.action_history
            )

            agent_type = self.config['pysbench.agent_type']
            if agent_type == 'openai':
                response, call_time = self.llm.generate_nav_instructions_openai(images, system_prompt, user_prompt)
            elif agent_type == 'openrouter':
                response, call_time = self.llm.generate_nav_instructions_openrouter(images, system_prompt, user_prompt)
            else:
                # Fail with a clear message instead of an unbound-variable error.
                raise ValueError(f"Unsupported pysbench.agent_type: {agent_type!r}")

            if response is None:
                # Failed calls count as error steps only, not towards max_steps.
                print('response is None')
                self.num_error_steps += 1
                continue

            self.num_steps += 1
            self.llm_call_time_sum += call_time
            # num_steps is at least 1 here, so the running mean is always defined.
            self.llm_call_time = self.llm_call_time_sum / self.num_steps
            vlm_action = ActionSpace.from_json(response)
            self.logger.info(f"Agent {self.id} is taking action {vlm_action}")

            if self.fine_grained_map.traffic_signals:
                current_node = self.fine_grained_map.get_closest_node(self.position)
                if current_node.type == 'intersection' and planned_node.type == 'crosswalk':
                    # About to enter a crosswalk: look up the closest signal
                    # that is nearer than twice the sidewalk offset.
                    traffic_light = None
                    min_distance = self.config['pysbench.sidewalk_offset'] * 2
                    for signal in self.fine_grained_map.traffic_signals:
                        distance = self.position.distance(signal.position)
                        if distance < min_distance:
                            min_distance = distance
                            traffic_light = signal

                    if traffic_light is not None:
                        state = traffic_light.get_state()
                        if state[1] == TrafficSignalState.PEDESTRIAN_RED and vlm_action.choice == Action.STEP_FORWARD:
                            self.red_light_violation = True

            if vlm_action.choice == Action.STEP_FORWARD:
                self.num_step_forward += 1
                self.communicator.agent_step_forward(self.id, vlm_action.duration, vlm_action.direction)
                if vlm_action.direction == 0:
                    action_str += f"I chose to step forward for {vlm_action.duration} seconds."
                else:
                    action_str += f"I chose to step backward for {vlm_action.duration} seconds."

                _human_collision, _object_collision = self.communicator.get_collision_number(self.id)

                # Only count collisions on the first forward step or when the
                # agent is more than 400 units from where the last one ended.
                if self.last_position is None or self.position.distance(self.last_position) > 400:
                    self.human_collision_count += _human_collision
                    self.object_collision_count += _object_collision
                    self.collision_count += _human_collision + _object_collision

                # Store a copy so later position updates do not alias it.
                self.last_position = Vector(self.position.x, self.position.y)

                if _human_collision > 0 or _object_collision > 0:
                    action_str += "But I have collided with something."

            elif vlm_action.choice == Action.TURN_AROUND:
                self.num_turn_around += 1
                clockwise = 'right' if vlm_action.clockwise else 'left'
                action_str += f"I chose to turn {clockwise} {vlm_action.angle} degrees."
                if vlm_action.angle < 10:
                    # Turns below 10 degrees are neither executed nor recorded
                    # in the action history.
                    continue
                self.communicator.agent_rotate(self.id, vlm_action.angle, clockwise)
            elif vlm_action.choice == Action.CHOOSE_WAYPOINT:
                self.num_choose_waypoint += 1
                # The chosen waypoint is adopted even when it is not valid;
                # validity is only tracked as a statistic.
                point = vlm_action.new_waypoint
                if point in possible_next_waypoints and self.fine_grained_map.get_closest_node(point).obstacle == False:
                    self.num_valid_choose_waypoint += 1
                action_str += f"I chose to go to {point}."
            elif vlm_action.choice == Action.DO_NOTHING:
                action_str += "I chose to do nothing."

            # Keep only the most recent ``history_window`` entries.
            self.action_history.append(action_str)
            if len(self.action_history) > self.history_window:
                self.action_history.pop(0)



    def _arrive_at_waypoint(self, waypoint):
        """Return True when the agent is closer to ``waypoint`` than the configured threshold.

        The threshold is read from ``config['pysbench.waypoint_distance_threshold']``.
        """
        arrival_radius = self.config['pysbench.waypoint_distance_threshold']
        remaining = self.position.distance(waypoint)
        return bool(remaining < arrival_radius)
    
    def _get_angle_and_direction(
        self,
        waypoint: Vector,
    ):
        """Return ``(angle_in_degrees, 'left' | 'right')`` needed to face ``waypoint``.

        Angles below 2 degrees are reported as ``(0.0, 'left')``.
        """
        offset = waypoint - self.position
        # Clamp the cosine so rounding noise cannot push acos out of its domain.
        cosine = np.clip(self.direction.dot(offset.normalize()), -1, 1)
        angle = math.degrees(math.acos(cosine))
        # The sign of the cross product selects the turning side.
        side = 'left' if self.direction.cross(offset) < 0 else 'right'
        if angle >= 2:
            return angle, side
        return 0.0, 'left'

    def _align_direction(self, waypoint: Vector) -> bool:
        """Tell whether the heading is within 10 degrees of the direction to ``waypoint``."""
        unit_to_wp = (waypoint - self.position).normalize()
        # Clamp the cosine so rounding noise cannot push acos out of its domain.
        cosine = np.clip(self.direction.dot(unit_to_wp), -1, 1)
        deviation = math.degrees(math.acos(cosine))
        return deviation < 10

    def get_path_positions(self):
        """Snap start and goal onto the map and return the route as a list of positions.

        Side effect: ``self.position`` and ``self.destination`` are overwritten
        with the positions of their closest non-obstacle nodes before routing.
        """
        grid = self.fine_grained_map
        self.position = grid.get_closest_non_obstacle_node(self.position).position
        self.destination = grid.get_closest_non_obstacle_node(self.destination).position

        start_node = grid.get_closest_node(self.position)
        goal_node = grid.get_closest_node(self.destination)
        route = grid.get_route(start_node, goal_node)
        return [step.position for step in route]


    def get_possible_next_waypoints(self, planned_next_waypoint: Vector):
        """List neighbouring node positions that lie roughly towards the planned waypoint.

        A neighbour of the node closest to the agent is kept when the dot
        product of its unit direction with the unit direction to
        ``planned_next_waypoint`` exceeds 0.707.
        """
        origin_node = self.fine_grained_map.get_closest_node(self.position)
        neighbours = self.fine_grained_map.adjacency_list[origin_node]
        heading = (planned_next_waypoint - self.position).normalize()

        return [
            neighbour.position
            for neighbour in neighbours
            if heading.dot((neighbour.position - self.position).normalize()) > 0.707
        ]
    
    def calculate_minimum_steps(self):
        """Return the number of segments in the planned path (waypoints minus one).

        An earlier variant compared the path against an obstacle-avoiding
        route and added two steps per differing node; it is no longer used.
        """
        segment_count = len(self.path) - 1
        return segment_count
        
