import random
import os
import json
import time
import yaml
from typing import List
from threading import Lock
from Base.Types import Vector, Road
from Config import Config
from Base import DeliveryMan, Map, Order, Node, Edge
import matplotlib.pyplot as plt
from Base.Map import Map
from persona_hub.code.persona_utils import PersonaManager
from Communicator import DeliveryCommunicator
from typing import Optional
import logging

# Module-level logger named after this module; DeliveryManager reports through it.
logger: logging.Logger = logging.getLogger(__name__)

class DeliveryManager:
    '''
    Platform-side coordinator of the delivery simulation.

    Holds the road map, the sampled customer/store/delivery-man positions, the
    DeliveryMan agents and the pool of open orders. Orders are created by
    `create_orders`, bid on via `bid_order`, and handed to the lowest bidder by
    `try_to_allocate_order`.

    Concurrency note: `self.lock` is a non-reentrant `threading.Lock` that
    guards `self.orders` and `self.notification`. A method that holds it must
    not call another method that acquires it (e.g. `to_json` relies on
    `get_notification` being lock-free).
    '''

    def __init__(self,
                 num_customers: Optional[int] = None,
                 num_stores: Optional[int] = None,
                 num_delivery_men: Optional[int] = None,
                 dataset_path: Optional[str] = None,
                 input_folder: Optional[str] = None,
                 communicator: Optional[DeliveryCommunicator] = None):
        '''
        Build the platform and immediately populate it via `initialize()`.

        Args:
            num_customers, num_stores, num_delivery_men: entity counts used by
                `default_initialize`; overwritten from metadata.json when
                `dataset_path` is given.
            dataset_path: folder containing roads.json, metadata.json and
                positions.json; when set, the platform is loaded from it.
            input_folder: folder containing roads.json; used when
                `dataset_path` is None.
            communicator: optional DeliveryCommunicator handed to every
                DeliveryMan (can also be set later with `load_communicator`).
        '''
        self.num_customers = num_customers
        self.num_stores = num_stores
        self.num_delivery_men = num_delivery_men
        self.dataset_path = dataset_path
        self.input_folder = input_folder
        self.map = Map()
        self.communicator = communicator
        self._customers: List[Vector] = []  # readonly
        self._stores: List[Vector] = []  # readonly
        self.delivery_men: List[DeliveryMan] = []
        self.orders: List[Order] = []  # open (not yet allocated) orders, guarded by self.lock
        self._supply_points: List[Vector] = []  # readonly
        self._delivery_men: List[Vector] = []  # readonly, initial delivery-man positions

        self.difficulty = Config.DIFFICULTY
        self.lock = Lock()
        self.persona_manager = PersonaManager(Config.NORMAL_PERSONA_PATH)
        self.notification = []  # most recent notifications (at most 5), guarded by self.lock
        self.order_count = 0  # total number of orders ever created by create_orders

        self.initialize()

        logger.warning(f"DeliveryManager initialized, difficulty: {self.difficulty}")
        for delivery_man in self.delivery_men:
            logger.warning(f"""
            DeliveryMan {delivery_man.id} initialized:
            llm: {delivery_man.llm.model_name}
            persona: {delivery_man.persona}
            init energy: {delivery_man.energy}
            init money: {delivery_man.money}
            init speed: {delivery_man.move_speed}
            ------------------------------------------""")

    def create_orders(self):
        '''
        Possibly create one new order and append it to the open-order pool.

        No order is created when:
        - there are no delivery men, stores or customers;
        - the pool already holds 10 or more orders;
        - 2000 orders have been created in total;
        - the share of delivery men holding fewer than Config.MAX_ORDERS orders
          is at or below Config.HUNGARY_RATE.

        Otherwise the new order uses the store with the fewest open orders. Its
        customer is sampled until the store-customer distance falls in the band
        for the current difficulty. Sampling stops after 1000 attempts, and the
        last sampled customer is then used as-is.
        '''
        # Guard clauses: without these, the rate division, min() over stores
        # and random.choice over customers would raise on empty collections.
        if not self.delivery_men or not self.stores or not self.customers:
            return

        hungry_delivery_men = [dm for dm in self.delivery_men if len(dm.orders) < Config.MAX_ORDERS]
        hungry_rate = len(hungry_delivery_men) / len(self.delivery_men)

        with self.lock:
            if len(self.orders) >= 10 or self.order_count >= 2000 or hungry_rate <= Config.HUNGARY_RATE:
                return

            store_order_count = {store: 0 for store in self.stores}
            for order in self.orders:
                # An order whose store_position is not one of the sampled stores
                # (possibly one added via add_shared_order - not verified) is
                # ignored rather than raising KeyError.
                if order.store_position in store_order_count:
                    store_order_count[order.store_position] += 1
            selected_store = min(store_order_count, key=store_order_count.get)

            max_iterations = 1000
            selected_customer = None
            for _ in range(max_iterations):
                selected_customer = random.choice(self.customers)
                if self._distance_matches_difficulty(selected_customer.distance(selected_store)):
                    break

            new_order = Order(selected_customer, selected_store, self.map)
            self.orders.append(new_order)
            self.order_count += 1

    def _distance_matches_difficulty(self, distance) -> bool:
        '''
        Return True if a store-customer distance lies in the band for self.difficulty.

        easy: distance <= 7000; medium: 7000 < distance <= 30000;
        hard: 30000 < distance <= 150000; insane: any distance.
        Any other difficulty value never matches.
        '''
        if self.difficulty == 'easy':
            return distance <= 7000
        if self.difficulty == 'medium':
            return 7000 < distance <= 30000
        if self.difficulty == 'hard':
            return 30000 < distance <= 150000
        return self.difficulty == 'insane'

    def load_communicator(self, communicator: DeliveryCommunicator):
        '''
        Set the communicator on the manager and on every existing DeliveryMan.
        '''
        self.communicator = communicator
        for delivery_man in self.delivery_men:
            delivery_man.communicator = communicator

    def get_orders(self):
        '''
        Return a shallow copy of the open-order list, taken under the lock.
        '''
        with self.lock:
            return self.orders.copy()

    def get_delivery_men(self):
        '''
        Return the delivery-man list itself (not a copy).
        '''
        return self.delivery_men

    def add_shared_order(self, order: Order):
        '''
        Append an order to the open-order pool under the lock.
        '''
        with self.lock:
            self.orders.append(order)

    def cancel_shared_order(self, order: Order):
        '''
        Remove `order` from the pool if it is present and marked `is_shared`.

        Returns:
            True if the order was removed, False otherwise.
        '''
        with self.lock:
            if order in self.orders and order.is_shared:
                self.orders.remove(order)
                return True
            return False

    def try_to_allocate_order(self):
        '''
        Close bidding on open orders whose bidding window has elapsed.

        An order that has a `start_time` is closed once its age reaches a random
        threshold of 20-60 seconds; the threshold is re-drawn on every check.
        A closed order is removed from the open pool. If it received bids, it
        is assigned to the lowest bidder at min(max_sale_price, lowest bid) and
        its `start_time` is reset. Orders closed without any bid are dropped.
        '''
        # 1. Under the lock, collect the orders whose bidding window is over.
        #    Iterate over a snapshot: removing from a list while iterating over
        #    it skips the element that follows each removed one.
        orders_to_process = []
        with self.lock:
            now = time.time()
            for order in list(self.orders):
                if hasattr(order, 'start_time') and now - order.start_time >= random.randint(20, 60):
                    if hasattr(order, 'bids') and order.bids:
                        # bids are (delivery_man, bid_price) pairs; lowest price wins
                        delivery_man, bid_price = min(order.bids, key=lambda bid: bid[1])
                        orders_to_process.append((order, delivery_man, bid_price))
                    self.orders.remove(order)

        # 2. Outside the lock, hand each order to its winning bidder.
        for order, delivery_man, bid_price in orders_to_process:
            order.sale_price = min(order.max_sale_price, bid_price)
            order.add_delivery_man(delivery_man)
            order.start_time = time.time()
            delivery_man.add_order(order)
            logger.warning(f"DeliveryMan {delivery_man} has won the bid for order {order.id} with price {bid_price}. participated: {order.bids}")

    def bid_order(self, order: Order, delivery_man: DeliveryMan, bid_price: float):
        '''
        Register a bid on an open order.

        Returns:
            True if the order is still open and the bid was recorded,
            False if the order is no longer in the pool.
        '''
        with self.lock:
            if order in self.orders:
                order.bid_order(delivery_man, bid_price)
                return True
            return False

    def finish_order(self, order: Order, delivery_man_id: int):
        '''
        Mark an order as delivered: pay out via `order.distribute_payment()`
        and post a notification.
        '''
        print(f"Mission success, order {order} is finished by DeliveryMan {delivery_man_id}")
        order.distribute_payment()
        self.update_notification(f"DeliveryMan {delivery_man_id} has delivered order {order}!")

    ##############################################################
    #################  Initialize the platform  ##################
    ##############################################################
    def initialize(self):
        '''
        Populate customers, stores and delivery men, then assign personas.

        Uses `dataset_initialize` when `dataset_path` is set, otherwise
        `default_initialize`. Every delivery man currently receives persona 0,
        and the persona cache is written to Config.SAMPLE_PERSONA_PATH.
        '''
        ############# sample datapoints #############
        if self.dataset_path is not None:
            self.dataset_initialize()
            logger.warning(f"Map loaded from {self.dataset_path}, dataset initialize, difficulty: {self.difficulty} from Dataset Config")
        else:
            self.default_initialize()
            logger.warning(f"Map loaded from {self.input_folder}, default initialize, difficulty: {self.difficulty} from Global Config")

        ############# assign personas #############
        for delivery_man in self.delivery_men:
            # Alternatives: self.persona_manager.get_random_persona()
            #               self.persona_manager.get_persona_by_id(delivery_man.id % 10)
            persona = self.persona_manager.get_persona_by_id(0)
            delivery_man.load_persona(persona)
            # Keep the persona cache in sync so to_json below exports it.
            self.persona_manager.persona_cache[delivery_man.id] = persona
        self.persona_manager.to_json(Config.SAMPLE_PERSONA_PATH)

    def dataset_initialize(self):
        '''
        Load the platform from `self.dataset_path`.

        Reads roads.json (map), metadata.json (difficulty and entity counts;
        these override the constructor arguments) and positions.json
        (store, customer and delivery-man coordinates).
        '''
        # TODO: make it more general, now it is fixed to read from maps_dataset/map_000 to maps_dataset/map_009

        roads_path = os.path.join(self.dataset_path, 'roads.json')
        metadata_path = os.path.join(self.dataset_path, 'metadata.json')
        position_path = os.path.join(self.dataset_path, 'positions.json')

        ############# load roads #############
        self.map.import_map(roads_path)

        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        self.difficulty = metadata['difficulty']
        self.num_customers = metadata['num_customers']
        self.num_stores = metadata['num_stores']
        # NOTE(review): num_delivery_men is read but not used to truncate the
        # positions below; every position in positions.json gets a delivery man.
        self.num_delivery_men = metadata['num_delivery_men']

        ############# load positions #############
        with open(position_path, 'r', encoding='utf-8') as f:
            positions = json.load(f)

        self._supply_points = self.map.get_supply_points()
        self._stores = [Vector(pos['x'], pos['y']) for pos in positions['stores']]
        self._customers = [Vector(pos['x'], pos['y']) for pos in positions['customers']]
        self._delivery_men = [Vector(pos['x'], pos['y']) for pos in positions['delivery_men']]

        logger.debug("Loaded %d delivery men positions", len(self._delivery_men))
        self.delivery_men = [self._new_delivery_man(pos) for pos in self._delivery_men]

    def _new_delivery_man(self, position):
        '''
        Create a DeliveryMan at `position` wired to this manager.

        The second positional argument is Vector(0, 0) (presumably the initial
        direction - confirm against DeliveryMan).
        '''
        return DeliveryMan(
            position,
            Vector(0, 0),
            use_planner=Config.USE_A2A_PLANNER,
            communicator=self.communicator,
            delivery_manager=self,
        )

    def default_initialize(self):
        '''
        Build the platform from `self.input_folder`/roads.json plus random sampling.

        Samples `num_stores` stores, then `num_customers` customers and
        `num_delivery_men` delivery men, both placed relative to the stores
        according to Config.DIFFICULTY. A difficulty of None behaves like
        'medium'; any other unknown value yields no customers and no delivery
        men. Finally, the sampled nodes are converted to positions.
        '''
        roads_file_path = os.path.join(self.input_folder, 'roads.json')
        self.map.import_map(roads_file_path)

        # Difficulty always comes from the global config here (it is not derived
        # from the road layout).
        self.difficulty = Config.DIFFICULTY

        ############# sample stores #############
        # Until the conversion at the end, _stores/_customers hold map node
        # objects (they expose `.position`), not positions.
        for _ in range(self.num_stores):
            self._stores.append(self.map.get_random_node(exclude_pos=self._stores))

        ############# sample customers #############
        if self.difficulty == 'easy':
            for _ in range(self.num_customers):
                self._customers.append(self.map.get_random_node_with_edge_distance(
                    base_pos=self._stores, exclude_pos=self._stores + self._customers,
                    min_distance=1, max_distance=4)[0])
        elif self.difficulty == 'medium' or self.difficulty is None:
            for _ in range(self.num_customers):
                self._customers.append(self.map.get_random_node_with_edge_distance(
                    base_pos=self._stores, exclude_pos=self._stores + self._customers,
                    min_distance=3, max_distance=10)[0])
        elif self.difficulty == 'hard':
            for _ in range(self.num_customers):
                self._customers.append(self.map.get_random_node_with_distance(
                    base_pos=self._stores, exclude_pos=self._stores + self._customers,
                    min_distance=10000, max_distance=100000))
        elif self.difficulty == 'insane':
            for _ in range(self.num_customers):
                self._customers.append(self.map.get_random_node(exclude_pos=self._stores + self._customers))

        ############# sample delivery men #############
        if self.difficulty == 'easy':
            for _ in range(self.num_delivery_men):
                node = self.map.get_random_node_with_edge_distance(
                    base_pos=self._stores, exclude_pos=None, min_distance=1, max_distance=4)[0]
                self.delivery_men.append(self._new_delivery_man(node.position))
        elif self.difficulty == 'medium' or self.difficulty is None:
            for _ in range(self.num_delivery_men):
                node = self.map.get_random_node_with_edge_distance(
                    base_pos=self._stores, exclude_pos=None, min_distance=3, max_distance=10)[0]
                self.delivery_men.append(self._new_delivery_man(node.position))
        elif self.difficulty == 'hard':
            for _ in range(self.num_delivery_men):
                node = self.map.get_random_node_with_distance(
                    base_pos=self._stores, exclude_pos=None, min_distance=10000, max_distance=100000)
                self.delivery_men.append(self._new_delivery_man(node.position))
        elif self.difficulty == 'insane':
            for _ in range(self.num_delivery_men):
                node = self.map.get_random_node(exclude_pos=None)
                self.delivery_men.append(self._new_delivery_man(node.position))

        ############# get points #############
        self._supply_points = self.map.get_supply_points()
        self._customers = [customer.position for customer in self._customers]
        self._stores = [store.position for store in self._stores]
        self._delivery_men = [delivery_man.position for delivery_man in self.delivery_men]

    def update_notification(self, notification: str):
        '''
        Append a notification, keeping only the 5 most recent ones.
        '''
        with self.lock:
            self.notification.append(notification)
            if len(self.notification) > 5:
                self.notification.pop(0)

    def get_notification(self):
        '''
        Return the buffered notifications as one string, each followed by "\\n".

        Deliberately does not take `self.lock`: `to_json` calls this while it
        already holds the (non-reentrant) lock. A list snapshot is iterated so
        a concurrent `update_notification` does not shift items mid-loop.
        '''
        return "".join(notification + "\n" for notification in list(self.notification))

    def get_step(self):
        '''
        Return the sum of `get_step()` over all delivery men.
        '''
        return sum(delivery_man.get_step() for delivery_man in self.delivery_men)

    def allocate_llm(self, config_file: str):
        '''
        Placeholder for allocating LLMs to the agents; currently does nothing.
        '''
        pass

    @property
    def customers(self):
        '''Customer positions (read-only view of the internal list).'''
        return self._customers

    @property
    def stores(self):
        '''Store positions (read-only view of the internal list).'''
        return self._stores

    @property
    def supply_points(self):
        '''Supply points as returned by `Map.get_supply_points()`.'''
        return self._supply_points

    def to_json(self):
        """
        Convert the DeliveryManager state to a JSON-serializable dict.

        Includes the map nodes/edges, customers, stores, delivery-man state,
        open orders and notifications. The whole snapshot is taken under
        `self.lock`.
        """
        with self.lock:
            data = {
                "map": {
                    "nodes": [
                        {
                            "position": {"x": node.position.x, "y": node.position.y},
                            "type": node.type
                        }
                        for node in self.map.nodes
                    ],
                    "edges": [
                        {
                            "node1": {"x": edge.node1.position.x, "y": edge.node1.position.y},
                            "node2": {"x": edge.node2.position.x, "y": edge.node2.position.y}
                        }
                        for edge in self.map.edges
                    ]
                },
                "customers": [{"x": c.x, "y": c.y} for c in self.customers],
                "stores": [{"x": s.x, "y": s.y} for s in self.stores],
                "delivery_men": [
                    {
                        "id": dm.id,
                        "position": {"x": dm.position.x, "y": dm.position.y},
                        "direction": {"x": dm.direction.x, "y": dm.direction.y},
                        "state": str(dm.get_state()),
                        "energy": dm.get_energy(),
                        "speed": dm.get_speed()
                    }
                    for dm in self.delivery_men
                ],
                "orders": [
                    {
                        "id": order.id,
                        "sale_price": order.sale_price,
                        "customer_position": {"x": order.customer_position.x, "y": order.customer_position.y},
                        "store_position": {"x": order.store_position.x, "y": order.store_position.y},
                        "has_picked_up": order.has_picked_up,
                        "has_delivered": order.has_delivered,
                        "estimated_time": order.estimated_time,
                        "is_shared": order.is_shared,
                        "meeting_point": {"x": order.meeting_point.x, "y": order.meeting_point.y} if order.meeting_point else None,
                        "spent_time": time.time() - order.start_time
                    }
                    for order in self.orders
                ],
                # get_notification does not acquire the lock, so calling it here is safe.
                "notification": self.get_notification()
            }
            return data

    def visualization(self, save_path: Optional[str] = None):
        """
        Plot the map: edges, nodes by type, stores, delivery men and customers.

        Parameters:
            save_path: Optional path to save the image to (300 dpi). If None,
                the figure is shown interactively instead.
        """
        map_obj: Map = self.map

        plt.figure(figsize=(12, 8))

        # draw edges
        for edge in map_obj.edges:
            x_coords = [edge.node1.position.x, edge.node2.position.x]
            y_coords = [edge.node1.position.y, edge.node2.position.y]
            plt.plot(x_coords, y_coords, 'gray', alpha=0.5, linewidth=1)

        # collect different types of nodes
        normal_nodes = []
        intersection_nodes = []
        supply_nodes = []

        for node in map_obj.nodes:
            if node.type == "normal":
                normal_nodes.append(node)
            elif node.type == "intersection":
                intersection_nodes.append(node)
            elif node.type == "supply":
                supply_nodes.append(node)

        # draw different types of map nodes
        if normal_nodes:
            x_coords = [node.position.x for node in normal_nodes]
            y_coords = [node.position.y for node in normal_nodes]
            plt.scatter(x_coords, y_coords, c='blue', s=80, label='Normal Nodes')

        if intersection_nodes:
            x_coords = [node.position.x for node in intersection_nodes]
            y_coords = [node.position.y for node in intersection_nodes]
            plt.scatter(x_coords, y_coords, c='red', s=100, label='Intersections')

        if supply_nodes:
            x_coords = [node.position.x for node in supply_nodes]
            y_coords = [node.position.y for node in supply_nodes]
            plt.scatter(x_coords, y_coords, c='green', s=100, marker='*', label='Supply Points')

        # visualize the dynamic entities of DeliveryManager
        def plot_vectors(vectors, color, label, marker='o', size=80):
            if vectors:
                x_coords = [v.x for v in vectors]
                y_coords = [v.y for v in vectors]
                plt.scatter(x_coords, y_coords, c=color, s=size, marker=marker, label=label)

        plot_vectors(self._stores, 'purple', 'Stores', marker='o', size=100)
        plot_vectors(self._delivery_men, 'cyan', 'Delivery Men', marker='x', size=50)
        plot_vectors(self._customers, 'orange', 'Customers', marker='+', size=50)

        plt.title('Delivery Manager Map Visualization')
        plt.xlabel('X Coordinate')
        plt.ylabel('Y Coordinate')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.axis('equal')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

    def get_dict(self):
        '''
        Return a lightweight snapshot: serialized open orders plus notifications.
        '''
        # get_orders() copies the pool under the lock, so a concurrent
        # try_to_allocate_order cannot mutate it while we serialize.
        return {
            "orders": [order.get_dict() for order in self.get_orders()],
            "notification": self.get_notification(),
        }

