#!/usr/bin/env python

import os
import carla
import numpy as np

from srunner.autoagents.sensor_interface import SensorInterface
from srunner.tools.route_manipulation import downsample_route

from agents.navigation.global_route_planner import GlobalRoutePlanner
from planning.networkx_astar_planner import AStarPlanner

from agent.autonomous_agent import AutonomousAgent
from agent.atomic_command_follower import AtomicCommandFollower
from agent.utils.agent_utils import convert_dict_location_to_carla_location
from comm.transceiver import Transceiver
from utils.partial_observable_captioner import PartialObservableCaptioner

class CommAgent(AutonomousAgent):
    """
    Host a communicating driving agent in the CARLA environment.

    The agent:
    (1) declares the sensors attached to its vehicle,
    (2) turns high-level commands into vehicle control through an
        ``AtomicCommandFollower``, and
    (3) sends and receives language messages through a ``Transceiver``.
    """
    def __init__(self,
                 vehicle,
                 agent_config,
                 ):
        """
        Args:
            vehicle: the vehicle this agent controls; its world/map are used
                for route planning.
            agent_config: agent configuration. It is read both by attribute
                (``name``, ``sensors``, ``task``, ``start``/``target`` or
                ``route``) and via ``.get("global_route_planner_type")`` --
                presumably an attribute-accessible mapping; confirm with caller.
        """
        self._agent_config = agent_config
        self._map = vehicle.get_world().get_map()
        self.agent_id = agent_config.name
        self.vehicle = vehicle
        self.sensor_devices = agent_config.sensors
        # Steps per second; used to convert step counts into seconds when
        # computing message ages.
        self.frame_rate = 20

        # Set up the Transceiver
        self.transceiver = Transceiver(vehicle=vehicle,
                                       agent=self,
                                       sharing_mode=True,
                                      )
        # History of received messages: dicts with 'sender', 'message', 'time' (step).
        self.received_messages = []

        # Set up global plan
        self._global_plan = None
        self._global_plan_world_coord = None
        self._sampling_resolution = 0.5
        if agent_config.get("global_route_planner_type", None) in ["grp"]:
            self._global_planner_type = "grp"
            self._global_planner = GlobalRoutePlanner(self._map, self._sampling_resolution)
        else:
            # Any other (or missing) planner type falls back to A*.
            self._global_planner_type = "astar"
            self._global_planner = AStarPlanner(world=vehicle.get_world(), ego_vehicle=vehicle, sampling_resolution=self._sampling_resolution)
        self._route = self.plan_route()

        # Set up sensor interface
        self.sensor_interface = SensorInterface()
        # Set up captioner
        self.captioner = PartialObservableCaptioner(vehicle,
                                                    perception_range=70,)
        # Set up language plan/goal/task
        self.task = agent_config.task
        # Set up the command follower and hand it the planned route
        self.command_follower = AtomicCommandFollower(vehicle)
        self.command_follower.set_global_plan(self._route)
        self.step_count = 0
        # Number of steps per decision interval: incoming messages are only
        # recorded on steps that are multiples of this value.
        self.decision_frequency = 10

    def get_observation(self):
        """
        Build the agent's observation from sensors, the captioner and messages.

        Returns:
            The dict returned by ``SensorInterface.get_data()`` extended with
            ``"observation"`` (captioner description), ``"task"`` and
            ``"received_messages"`` (up to 4 recent, fresh messages as
            returned by ``get_received_messages``).
        """
        observation = self.sensor_interface.get_data()
        env_description = self.captioner.get_description()
        observation.update({"observation": env_description})
        observation.update({"task": self.task})

        # The transceiver is polled every step, even when the result is not
        # recorded below.
        communication_data = self.transceiver.get_lang_data()

        # Record received messages with the current step as their time stamp,
        # but only on decision steps.
        if self.step_count % self.decision_frequency == 0 and communication_data:
            for sender_id, message in communication_data.items():
                # Assumes each message exposes a 'language_message' attribute.
                self.received_messages.append({
                    'sender': str(sender_id),
                    'message': str(message.language_message),
                    'time': self.step_count,
                })

        # Attach the recent, still-fresh messages to the observation.
        recent_messages = self.get_received_messages(num_messages=4)
        observation.update({"received_messages": recent_messages})
        return observation

    def run_step(self, action):
        """
        Apply one action and advance the step counter.

        Args:
            action: dict of the form ``{"command": command, "message": message}``;
                either key may be missing or falsy, in which case that part
                is skipped.

        Returns:
            The control produced by the command follower when a command is
            given, otherwise a default ``carla.VehicleControl()``.
        """
        control = carla.VehicleControl()
        # Step 1: translate the high-level command into vehicle control. The
        # second argument is this step's position within the current decision
        # interval (step_count % decision_frequency).
        if action.get("command", None):
            control = self.command_follower.run_step(action["command"], self.step_count % self.decision_frequency, debug=True)
        # Step 2 (TODO): wrap control to be safe.
        # Step 3: broadcast the outgoing message, if any. Outgoing messages are
        # not added to self.received_messages; incoming ones are recorded in
        # get_observation().
        if action.get("message", None):
            self.transceiver.TX_Language(action["message"])

        self.step_count += 1
        return control

    def get_received_messages(self, num_messages, max_age_seconds=2):
        """
        Return the most recent received messages that are still fresh.

        Args:
            num_messages: how many of the most recently recorded messages to
                consider. Non-positive values yield an empty list.
            max_age_seconds: messages whose computed age exceeds this many
                seconds are dropped (default 2, the previous hard-coded limit).

        Returns:
            A list of dicts with keys ``'sender'`` (``"Vehicle <id>"``),
            ``'message'`` and ``'age_seconds'`` (age rounded to 2 decimals,
            as a string), oldest first.
        """
        # Guard explicitly: ``lst[-0:]`` is the whole list, so num_messages == 0
        # would otherwise return every stored message.
        if num_messages <= 0:
            return []

        messages_info = []
        current_step = self.step_count

        for msg in self.received_messages[-num_messages:]:
            # NOTE(review): decision_frequency is added to the raw step
            # difference, so ages are offset by one decision interval --
            # presumably intentional; confirm.
            message_age_steps = current_step - msg['time'] + self.decision_frequency
            message_age_seconds = np.round(message_age_steps / self.frame_rate, 2)
            if message_age_seconds <= max_age_seconds:
                messages_info.append({
                    'sender': 'Vehicle ' + str(msg['sender']),
                    'message': str(msg['message']),
                    'age_seconds': str(message_age_seconds),
                })
        return messages_info

    def set_global_plan(self, global_plan_gps, global_plan_world_coord):
        """
        Set the global plan (route) for the agent.

        Both plans are reduced to the indices returned by
        ``downsample_route(global_plan_world_coord, 1)``; world-coordinate
        entries keep only their first two elements.
        """
        ds_ids = downsample_route(global_plan_world_coord, 1)
        self._global_plan_world_coord = [(global_plan_world_coord[x][0], global_plan_world_coord[x][1])
                                         for x in ds_ids]
        self._global_plan = [global_plan_gps[x] for x in ds_ids]

    def sensors(self):
        """
        :return: a list of sensor specifications attached to the agent; only a
                 front RGB camera ("Center") is supported, and only when it is
                 configured in ``self.sensor_devices``.
        """
        sensors = []
        if "Center" in self.sensor_devices:
            sensor_spec = self.sensor_devices["Center"]
            sensors.append(
            {"type": "sensor.camera.rgb", "x": 0.7, "y": 0.0, "z": 1.60, "roll": 0.0, "pitch": 0.0, "yaw": 0.0,
             "width": float(sensor_spec.camera_width), "height": float(sensor_spec.camera_height), "fov": 100, "id": "Center"},
            )
        return sensors

    def plan_route(self):
        """
        Plan the global route with the configured planner.

        "grp" traces from ``agent_config.start`` to ``agent_config.target``;
        "astar" traces through every point of ``agent_config.route``.

        Returns:
            The planner's ``trace_route`` result, or ``[]`` for any other
            planner type.
        """
        if self._global_planner_type == "grp":
            start = convert_dict_location_to_carla_location(self._agent_config.start)
            target = convert_dict_location_to_carla_location(self._agent_config.target)
            return self._global_planner.trace_route(start, target)
        if self._global_planner_type == "astar":
            route = [convert_dict_location_to_carla_location(point)
                     for point in self._agent_config.route]
            return self._global_planner.trace_route(route)
        return []
