import random
import json
import os
import time
import traceback
from simworld.utils.vector import Vector
from simworld.traffic.base.traffic_signal import TrafficSignalState
from map.map import Map
from map.world_generator import WorldGenerator
from simworld.utils.logger import Logger
from agent.nav_agent import NavAgent
from llm.nav_llm import NavLLM
from concurrent.futures import ThreadPoolExecutor
from threading import Event

class TaskManager:
    """Set up navigation tasks, run their agents concurrently, and record metrics.

    The manager builds a coarse- and a fine-grained map from the configured road
    file, creates one ``NavAgent`` per task of the selected task level, drives the
    simulation loop while the agents run in a thread pool, and finally writes a
    per-task metrics JSON file under ``metrics/``.
    """

    # Environment variable that holds the API key for each supported LLM provider.
    _API_KEY_ENV_VARS = {
        'openai': 'OPENAI_API_KEY',
        'openrouter': 'OPENROUTER_API_KEY',
    }

    def __init__(self, config, communicator, traffic_controller = None):
        """Seed the RNG, store collaborators, and build maps and agents.

        Args:
            config: Mapping of dotted configuration keys (``simworld.*``, ``pysbench.*``).
            communicator: Simulator bridge used to spawn objects and query state.
            traffic_controller: Optional controller owning vehicles, pedestrians and
                traffic signals; ``None`` runs without background traffic.
        """
        random.seed(config['simworld.seed'])

        self.config = config
        self.communicator = communicator
        self.world_generator = WorldGenerator(self.config)
        self.traffic_controller = traffic_controller
        self.dt = self.config['simworld.dt']

        # Passed to every agent's run(); set to request that the agents stop.
        self.exit_event = Event()

        # Position samples used for stuck detection: {agent_id: [(timestamp, position), ...]}
        self.position_history = {}
        self.check_interval = 2.0  # seconds between position samples
        self.logger = Logger().get_logger('PhysicalTaskManager')

        self.coarse_grained_map, self.fine_grained_map, self.agents = self.initialize()

    def initialize(self):
        """Build both maps, generate the world, and create the task agents.

        Returns:
            tuple: ``(coarse_grained_map, fine_grained_map, agents)``.
        """
        if self.traffic_controller is None:
            traffic_signals = None
        else:
            traffic_signals = self.traffic_controller.traffic_signals

        # The fine-grained map is the one the world generator populates.
        fine_grained_map = Map(self.config, traffic_signals)
        fine_grained_map.initialize_map_from_file(self.config['pysbench.input_roads'], fine_grained=True)
        self.world_generator.generate_world(fine_grained_map, self.config['pysbench.input_world_path'], self.config['pysbench.output_world_path'])
        coarse_grained_map = Map(self.config, traffic_signals)
        coarse_grained_map.initialize_map_from_file(self.config['pysbench.input_roads'], fine_grained=False)

        agents = self.initialize_agents(coarse_grained_map, fine_grained_map, self.config['pysbench.task_level'])
        return coarse_grained_map, fine_grained_map, agents

    def initialize_agents(self, coarse_grained_map, fine_grained_map, task_level=1, tasks_per_level=10):
        """Create one navigation agent per task of the requested level.

        Level ``k`` covers tasks ``(k-1)*tasks_per_level + 1`` through
        ``k*tasks_per_level``, looked up as ``"task<N>"`` in the task file.

        Args:
            coarse_grained_map: Map used by agents for high-level planning.
            fine_grained_map: Map used by agents for detailed navigation.
            task_level: 1-based task level.
            tasks_per_level: Number of consecutive tasks in each level.

        Returns:
            list: The created ``NavAgent`` instances, in task order.

        Raises:
            ValueError: If the agent's LLM provider is unsupported or a task of the
                level is missing from the task file.
        """
        task_file = self.config['pysbench.task_file']
        with open(task_file, 'r', encoding='utf-8') as f:
            task = json.load(f)

        self.logger.info(f"Task file: {task_file}")

        start_task = (task_level - 1) * tasks_per_level + 1
        end_task = task_level * tasks_per_level
        self.logger.info(f"Initializing agents for tasks {start_task} to {end_task}")

        with open(self.config['pysbench.agent_file'], 'r', encoding='utf-8') as f:
            agent_config = json.load(f)[self.config['pysbench.agent_type']]

        self.model_name = agent_config['model'].split('/')[-1]

        # Fail fast on an unknown provider instead of hitting an unbound `llm` later.
        provider = agent_config['provider']
        if provider not in self._API_KEY_ENV_VARS:
            raise ValueError(
                f"Unsupported LLM provider '{provider}'; expected one of {sorted(self._API_KEY_ENV_VARS)}"
            )
        api_key = os.getenv(self._API_KEY_ENV_VARS[provider])

        agents = []
        for task_idx in range(start_task, end_task + 1):
            task_name = f"task{task_idx}"
            if task_name not in task:
                raise ValueError(f"Task {task_name} not found in {task_file}")
            selected_task = task[task_name]

            origin_position = Vector(selected_task['origin']['x'], selected_task['origin']['y'])
            destination_position = Vector(selected_task['destination']['x'], selected_task['destination']['y'])

            # A separate LLM client per agent, as before.
            llm = NavLLM(agent_config['model'], agent_config['url'], api_key)

            # Random initial heading; draws from the seeded RNG in the same order as before.
            initial_direction = Vector(random.uniform(-1, 1), random.uniform(-1, 1)).normalize()
            agent = NavAgent(
                origin_position, initial_direction, self.communicator, coarse_grained_map, fine_grained_map,
                llm, destination_position, self.config['pysbench.rule_based'], self.config, task_name
            )
            agents.append(agent)

        return agents

    def spawn_world(self, world_file, assets_file='data/ue_assets.json'):
        """Ask the simulator to build the world described by ``world_file``.

        Args:
            world_file: Path of the world description to spawn.
            assets_file: Asset catalogue passed to the communicator.
        """
        self.communicator.generate_world(world_file, assets_file)

    def spawn_agents(self):
        """Spawn every agent in the simulator and apply the configured speed."""
        for agent in self.agents:
            self.communicator.spawn_agent(agent, self.config['pysbench.model_path'])
            self.communicator.agent_set_speed(agent.id, self.config['pysbench.speed'])

    def update_physical_states(self):
        """Pull the latest simulator state and mirror it onto the local objects.

        The communicator returns a dict keyed by ``(object_type, object_id)``.
        Vehicles, pedestrians and agents carry ``(position, direction)``. Traffic
        signals carry ``(is_vehicle_green, is_pedestrian_walk, left_time)``.
        """
        controller = self.traffic_controller
        if controller is not None:
            vehicle_ids = [vehicle.id for vehicle in controller.vehicles]
            pedestrian_ids = [pedestrian.id for pedestrian in controller.pedestrians]
            traffic_signal_ids = [signal.id for signal in controller.traffic_signals]
            # Index signals once (first signal wins per id, like a linear first-match search).
            signals_by_id = {}
            for signal in controller.traffic_signals:
                signals_by_id.setdefault(signal.id, signal)
        else:
            vehicle_ids = []
            pedestrian_ids = []
            traffic_signal_ids = []
            signals_by_id = {}

        agent_ids = [agent.id for agent in self.agents]

        result = self.communicator.get_position_and_direction(vehicle_ids, pedestrian_ids, traffic_signal_ids, agent_ids)
        for (obj_type, object_id), values in result.items():
            # NOTE(review): vehicles, pedestrians and agents are indexed by id, which
            # assumes each id equals the object's index in its list — confirm.
            if obj_type == 'vehicle':
                position, direction = values
                vehicle = controller.vehicles[object_id]
                vehicle.position = position
                vehicle.direction = direction
            elif obj_type == 'pedestrian':
                position, direction = values
                pedestrian = controller.pedestrians[object_id]
                pedestrian.position = position
                pedestrian.direction = direction
            elif obj_type == 'traffic_signal':
                is_vehicle_green, is_pedestrian_walk, left_time = values
                signal = signals_by_id.get(object_id)
                if signal is None:
                    continue
                if is_vehicle_green:
                    signal.set_state((TrafficSignalState.VEHICLE_GREEN, TrafficSignalState.PEDESTRIAN_RED))
                elif is_pedestrian_walk:
                    signal.set_state((TrafficSignalState.VEHICLE_RED, TrafficSignalState.PEDESTRIAN_GREEN))
                else:
                    signal.set_state((TrafficSignalState.VEHICLE_RED, TrafficSignalState.PEDESTRIAN_RED))
                signal.set_left_time(left_time)
            elif obj_type == 'agent':
                position, direction = values
                agent = self.agents[object_id]
                agent.position = position
                agent.direction = direction

    @staticmethod
    def _is_stuck(history, window=60, radius=300):
        """Decide whether an agent stayed inside a small area over its recent history.

        Args:
            history: Sequence of ``(timestamp, position)`` samples; positions need
                ``x`` and ``y`` attributes.
            window: Number of most recent samples to inspect. With the default
                2-second sampling interval, 60 samples span two minutes.
            radius: Threshold, in position units, for both the mean and the RMS
                distance from the samples' centroid.

        Returns:
            bool: ``True`` only if more than ``window`` samples exist and both
            spread measures are below ``radius``.
        """
        if len(history) <= window:
            return False

        positions = [pos for _, pos in history[-window:]]
        count = len(positions)
        mean_x = sum(pos.x for pos in positions) / count
        mean_y = sum(pos.y for pos in positions) / count

        distances = [((pos.x - mean_x) ** 2 + (pos.y - mean_y) ** 2) ** 0.5 for pos in positions]
        avg_displacement = sum(distances) / count
        # True positional standard deviation: root of the mean *squared* distance.
        # (The previous code took the root of the mean distance, which is unit-less
        # and made this check redundant.)
        std_dev = (sum(d * d for d in distances) / count) ** 0.5

        return std_dev < radius and avg_displacement < radius

    def evaluate(self, env='pure', elapsed_time=0):
        """Evaluate the simulation run and save metrics to a JSON file.

        Sets ``agent.get_stuck`` on every agent, then writes one metrics entry per
        task name to ``metrics/<task-file-stem>_<level>_<env>_<model|rule>_<timestamp>.json``.

        Args:
            env: Environment label embedded in the output file name.
            elapsed_time: Wall-clock seconds of the run; reported as ``time`` for
                agents that did not succeed.
        """
        os.makedirs('metrics', exist_ok=True)

        all_metrics = {}
        for agent in self.agents:
            # Agents without recorded history (e.g. run aborted before submission) are not stuck.
            agent.get_stuck = self._is_stuck(self.position_history.get(agent.id, []))

            all_metrics[agent.task_name] = {
                "success": agent.success,
                "time": round(agent.elapsed_time, 2) if agent.success else round(elapsed_time, 2),
                "collision_count": agent.collision_count,
                "num_steps": agent.num_steps,
                "num_step_forward": agent.num_step_forward,
                "num_turn_around": agent.num_turn_around,
                "num_choose_waypoint": agent.num_choose_waypoint,
                "red_light_violation": agent.red_light_violation,
                "minimum_steps": agent.minimum_steps,
                "get_stuck": agent.get_stuck,
                "num_valid_choose_waypoint": agent.num_valid_choose_waypoint,
                "llm_call_time": agent.llm_call_time,
                "human_collision_count": agent.human_collision_count,
                "object_collision_count": agent.object_collision_count
            }

        rule_based_str = 'rule' if self.config['pysbench.rule_based'] else self.model_name
        # splitext keeps multi-dot names intact ("tasks.v2.json" -> "tasks.v2").
        task_stem = os.path.splitext(os.path.basename(self.config['pysbench.task_file']))[0]
        timestamp = time.strftime('%Y%m%d_%H%M%S')
        output_file = f"metrics/{task_stem}_{self.config['pysbench.task_level']}_{env}_{rule_based_str}_{timestamp}.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(all_metrics, f, indent=4)

        self.logger.info(f"Metrics saved to {output_file}")

    def _wait_for_agents(self, futures):
        """Block until every agent future resolves, logging (not raising) agent failures."""
        for future in futures:
            try:
                future.result()
            except Exception as e:
                self.logger.error(f"Error while waiting for agent to finish: {e}\n{traceback.format_exc()}")

    def _supervise(self, futures, start_time, time_limit):
        """Drive the simulation loop until all agents finish or the time limit elapses.

        Each tick samples agent positions (every ``check_interval`` seconds), advances
        intersections and pedestrians when a traffic controller exists, refreshes
        physical states from the simulator, and sleeps ``dt`` seconds.

        Returns:
            tuple: ``(elapsed_seconds, timed_out)``.
        """
        last_check_time = start_time
        while True:
            current_time = time.time()
            elapsed_time = current_time - start_time

            if time_limit is not None and elapsed_time >= time_limit:
                self.logger.info(f"Time limit of {time_limit} seconds reached. Stopping simulation...")
                return elapsed_time, True

            if all(future.done() for future in futures):
                self.logger.info("All agents have finished their tasks.")
                self.logger.info(f"Total time elapsed: {elapsed_time:.2f} seconds")
                return elapsed_time, False

            if current_time - last_check_time >= self.check_interval:
                for agent in self.agents:
                    # Successful agents stop accumulating history.
                    if not agent.success:
                        self.position_history[agent.id].append((current_time, agent.position))
                last_check_time = current_time

            if self.traffic_controller is not None:
                self.traffic_controller.intersection_manager.update_intersections(self.communicator)
                self.traffic_controller.pedestrian_manager.update_pedestrians(self.communicator, self.traffic_controller.intersection_manager)

            self.update_physical_states()

            time.sleep(self.dt)

    def run_task(self, time_limit = None, env = 'pure'):
        """Run every agent concurrently, drive the simulation, then write metrics.

        Args:
            time_limit: Optional wall-clock budget in seconds; when reached, agents
                are asked to stop via ``exit_event``.
            env: Environment label forwarded to :meth:`evaluate`.

        ``KeyboardInterrupt`` and any exception raised while starting or supervising
        the agents are caught. In that case agents are stopped and metrics are still
        written.
        """
        with ThreadPoolExecutor(max_workers=self.config['pysbench.num_threads']) as executor:
            # Bound before the try so the interruption path can always use them.
            futures = []
            start_time = time.time()
            try:
                for agent in self.agents:
                    futures.append(executor.submit(agent.run, self.exit_event))
                    self.position_history[agent.id] = [(start_time, agent.position)]
                elapsed_time, stop_agents = self._supervise(futures, start_time, time_limit)
            except (KeyboardInterrupt, Exception) as exc:
                elapsed_time = time.time() - start_time
                stop_agents = True
                self.logger.info("Simulation interrupted")
                if not isinstance(exc, KeyboardInterrupt):
                    # Surface the real failure instead of hiding it behind "interrupted".
                    self.logger.error(f"Simulation loop failed: {exc}\n{traceback.format_exc()}")
                self.logger.info(f"Time elapsed before interruption: {elapsed_time:.2f} seconds")

            if stop_agents:
                self.exit_event.set()
            # On a normal finish every future is already done, so this only reports
            # agent failures; otherwise it waits for agents to honor exit_event.
            self._wait_for_agents(futures)
            self.evaluate(env=env, elapsed_time=elapsed_time)


