import numpy as np
import random

from utils import generate_cvrp_instance


class Event:
    """A scheduled simulation event for one vehicle.

    Attributes:
        time: simulation time at which the event fires.
        vehicle_id: id of the vehicle the event belongs to.
        target_node: node the vehicle arrives at when the event fires, or
            ``None`` for an event that carries no move.
    """

    def __init__(self, time, vehicle_id, target_node):
        self.time, self.vehicle_id, self.target_node = time, vehicle_id, target_node

    def __lt__(self, other):
        # Events compare purely by firing time so they can be ordered chronologically.
        return self.time < other.time

class EventManager:
    """Pending-event store plus a monotonic simulation clock.

    Events live in a plain list, ``self.events``. ``add_event`` keeps that
    list sorted by ``time``. Code outside this class may also append to
    ``self.events`` directly. For that reason, ``get_and_exe_next_event``
    does not rely on the list being sorted: it always returns the earliest
    pending event.
    """

    def __init__(self):
        self.events = []  # pending events (plain list, not a priority queue)
        self.current_time = 0.0  # simulation clock; advance_time never moves it backwards
        self.events_his = []  # history of processed events; filled in by the caller

    def add_event(self, event):
        """Add ``event`` and keep ``self.events`` sorted by ``event.time``.

        ``list.sort`` is stable, so events with equal times keep their
        insertion order.
        """
        self.events.append(event)
        self.events.sort(key=lambda e: e.time)

    def get_and_exe_next_event(self):
        """Remove and return the earliest pending event, or ``None`` if there is none.

        The earliest event is found by a linear scan instead of assuming the
        list is sorted, because callers may append to ``self.events``
        directly. Among events with equal time, the one that appears first
        in the list is returned. On a sorted list, this is the same event
        that ``pop(0)`` would return.
        """
        if not self.events:
            return None
        earliest_idx = min(range(len(self.events)), key=lambda i: self.events[i].time)
        return self.events.pop(earliest_idx)

    def advance_time(self, new_time):
        """Move the clock to ``new_time``, ignoring values earlier than the current time."""
        self.current_time = max(self.current_time, new_time)

    def clear(self):
        """Drop all pending events and the history, and reset the clock to 0."""
        self.events = []
        self.events_his = []
        self.current_time = 0.0

class Vehicle:
    """One delivery vehicle: capacity limit, current load, position and route history.

    Node 0 is the depot. Arriving at the depot unloads the vehicle and marks
    it as finished.
    """

    def __init__(self, vehicle_id, capacity, speed):
        self.id = vehicle_id
        self.type = 'Vehicle'
        self.capacity = capacity
        self.speed = speed
        self._clear_trip_state()

    def _clear_trip_state(self):
        """Put the vehicle back at the depot: empty, idle and not finished."""
        self.position = 0  # current node id (0 = depot)
        self.route = [0]  # visited nodes, starting at the depot
        self.current_load = 0.0
        self.available_time = 0.0  # time at which the current/last move completes
        self.is_busy = False  # True between start_move_to and finish_move_to
        self.is_finished = False  # set once the vehicle arrives back at the depot

    def can_serve_customer(self, customer_demand):
        """Return True if ``customer_demand`` still fits into the remaining capacity."""
        return self.current_load + customer_demand <= self.capacity

    def load_customer(self, demand):
        """Add ``demand`` to the load if it fits; return whether it was loaded."""
        if not self.can_serve_customer(demand):
            return False
        self.current_load += demand
        return True

    def start_move_to(self, node_id, travel_time, current_time):
        """Begin a move and return its arrival time.

        ``node_id`` is not used here; the position only changes in
        ``finish_move_to``.
        """
        self.is_busy = True
        self.available_time = current_time + travel_time
        return self.available_time

    def finish_move_to(self, node_id):
        """Complete a move: update the position and route, and mark the vehicle idle.

        Arriving at the depot (node 0) empties the vehicle and marks it as
        finished.
        """
        self.is_busy = False
        self.position = node_id
        self.route.append(node_id)
        if node_id != 0:
            return
        self.current_load = 0.0
        self.is_finished = True

    def get_state(self):
        """Return [id, position, load, capacity, available_time, busy] as floats."""
        fields = (
            self.id,
            self.position,
            self.current_load,
            self.capacity,
            self.available_time,
            self.is_busy,  # bool -> 1.0 / 0.0
        )
        return [float(value) for value in fields]

    def reset(self):
        """Restore the initial trip state; id, capacity and speed are kept."""
        self._clear_trip_state()


class CVRPEnvironment:
    """Event-driven CVRP (capacitated vehicle routing) environment with a gym-like API.

    Every vehicle starts at the depot (node 0). The environment always has
    one "current" vehicle: the owner of the most recently processed event.
    ``step(action)`` sends that vehicle to node ``action`` and schedules its
    arrival. It then advances the simulation until some vehicle can act
    again, or until nothing is left to simulate.
    """

    def __init__(self,
                 distance_matrix,
                 customer_demands,
                 n_vehicles,
                 vehicle_capacity,
                 vehicle_speed):
        """
        Args:
            distance_matrix: distance matrix (n_nodes x n_nodes); row/column 0
                is the depot. Presumably a NumPy array (``.copy()``,
                ``.shape`` and ``.max()`` are used) -- confirm with callers.
            customer_demands: dict ``{customer_id: demand}``; customers missing
                from it are treated as having demand 0.
            n_vehicles: number of vehicles.
            vehicle_capacity: capacity of every vehicle.
            vehicle_speed: stored on each Vehicle.
                NOTE(review): not used when computing travel times in ``step``.
        """
        self.distance_matrix = distance_matrix.copy()
        self.customer_demands = customer_demands.copy()
        self.n_nodes = distance_matrix.shape[0]
        self.n_customers = self.n_nodes - 1  # every node except the depot
        self.n_vehicles = n_vehicles

        # Fleet
        self.vehicles = [Vehicle(i, vehicle_capacity, vehicle_speed) for i in range(self.n_vehicles)]

        # Event queue and simulation clock
        self.event_manager = EventManager()

        # Episode state
        self.unserved_customers = list(range(1, self.n_nodes))
        self.served_customers = []
        self.available_actions = []
        self.current_vehicle_idx = 0  # vehicle that must choose the next action
        self.done = False

        # Statistics
        self.total_distance = 0.0

    def reset(self):
        """Reset vehicles, events and episode state; return the initial state vector."""
        for vehicle in self.vehicles:
            vehicle.reset()

        self.event_manager.clear()

        # Every vehicle becomes ready at time 0 (no move attached).
        for vehicle in self.vehicles:
            self.event_manager.add_event(Event(0.0, vehicle.id, None))

        self.unserved_customers = list(range(1, self.n_nodes))
        self.served_customers.clear()
        self.current_vehicle_idx = 0
        self.done = False
        self.total_distance = 0.0
        self._update()

        return self.get_state()

    def _process_next_event(self):
        """Pop the earliest pending event and apply it; do nothing if none is pending.

        The popped event is recorded in ``events_his``, the clock advances to
        its time, and its vehicle becomes the current vehicle. If the event
        carries a target node, that vehicle's move is completed.
        """
        next_event = self.event_manager.get_and_exe_next_event()
        if next_event is None:
            # Nothing pending; do not record a None entry in the history.
            return
        self.event_manager.events_his.append(next_event)
        self.event_manager.advance_time(next_event.time)
        self.current_vehicle_idx = next_event.vehicle_id
        if next_event.target_node is not None:
            self.vehicles[next_event.vehicle_id].finish_move_to(next_event.target_node)

    def get_available_actions(self, current_vehicle_idx):
        """Return the feasible actions (node ids) for vehicle ``current_vehicle_idx``.

        A finished vehicle has no actions. Otherwise, every unserved customer
        whose demand still fits is feasible. Returning to the depot (0) is
        also allowed once fewer customers remain than there are vehicles. It
        is the only action when no customer fits.
        """
        vehicle = self.vehicles[current_vehicle_idx]

        # A vehicle that has returned to the depot may not depart again.
        if vehicle.is_finished:
            return []

        actions = [
            customer_id
            for customer_id in self.unserved_customers
            if vehicle.can_serve_customer(self.customer_demands.get(customer_id, 0))
        ]
        # With few customers left, let some vehicles head back to the depot early.
        if len(self.unserved_customers) < self.n_vehicles:
            actions.append(0)

        # Nothing serviceable: the only option is to return to the depot.
        return actions or [0]

    def step(self, action):
        """Send the current vehicle to node ``action`` and advance the simulation.

        Returns:
            ``(state, reward, done, info)``. The reward is
            ``distance_matrix.max()`` minus the time-scaled length of this
            move. An invalid action or exceeded capacity returns reward
            ``-np.inf`` with an ``"error"`` entry in ``info`` and leaves the
            environment unchanged.
        """
        current_vehicle = self.vehicles[self.current_vehicle_idx]

        if action not in self.available_actions:
            return self.get_state(), -np.inf, self.done, {"error": "Invalid action"}

        # Validate capacity before touching any state, so that a rejected
        # action has no side effects.
        customer_demand = 0
        if action != 0:
            customer_demand = self.customer_demands.get(action, 0)
            if not current_vehicle.can_serve_customer(customer_demand):
                return self.get_state(), -np.inf, self.done, {"error": "Capacity exceeded"}

        distance = self.distance_matrix[current_vehicle.position][action]
        self.total_distance += distance

        # NOTE(review): sine_with_noise_at_time is neither defined nor imported
        # in this file (presumably it belongs in utils) -- add the import.
        travel_time = distance * sine_with_noise_at_time(self.event_manager.current_time)

        arrival_time = current_vehicle.start_move_to(action, travel_time, self.event_manager.current_time)

        if action != 0:
            current_vehicle.load_customer(customer_demand)
            self.served_customers.append(action)
            self.unserved_customers.remove(action)

        # Schedule through add_event so the queue stays in chronological order.
        # Appending to event_manager.events directly could let a later arrival
        # be processed before an earlier one.
        self.event_manager.add_event(Event(arrival_time, current_vehicle.id, action))

        # Reward the travel time actually simulated above. Calling the noisy
        # factor a second time would make the reward disagree with the move.
        distance_new = travel_time
        reward = self.distance_matrix.max() - distance_new
        self._update()
        return self.get_state(), reward, self.done, {"distance": distance_new, "vehicle": self.current_vehicle_idx}

    def _update(self):
        """Advance to the next decision point, then update ``done`` and depot returns.

        Events are processed until the current vehicle has at least one
        feasible action, or until the queue is empty. This runs in both
        phases of an episode. Previously, a finished current vehicle was
        only skipped while customers remained. After the last customer was
        served, that could leave ``available_actions`` empty while
        ``done`` was still False. A loop replaces the old recursion so that
        long runs of skipped events cannot hit the recursion limit.
        """
        while True:
            self._process_next_event()
            self.available_actions = self.get_available_actions(self.current_vehicle_idx)
            if self.available_actions or not self.event_manager.events:
                break

        if not self.unserved_customers:
            # All customers served: the episode ends once every vehicle is finished.
            if all(vehicle.is_finished for vehicle in self.vehicles):
                self.done = True
        elif not self.available_actions:
            # Customers remain, but no vehicle can act and nothing is pending.
            self.done = True

        # Pending return-to-depot events are applied immediately, not at their scheduled time.
        depot_returns = [event for event in self.event_manager.events if event.target_node == 0]
        for event in depot_returns:
            self.event_manager.events.remove(event)
            self.vehicles[event.vehicle_id].finish_move_to(event.target_node)

    def get_state(self):
        """Build the flat state vector.

        Layout: current vehicle state (6 values), served mask (n_customers),
        normalized remaining demand (n_customers), [completion rate, current
        vehicle index / n_vehicles, total distance / 1000], current vehicle
        position.
        """
        current_vehicle = self.vehicles[self.current_vehicle_idx]
        served = set(self.served_customers)  # O(1) membership instead of scanning the list
        customer_ids = range(1, self.n_nodes)

        state = list(current_vehicle.get_state())

        # Bit mask of served customers
        state.extend(1.0 if cid in served else 0.0 for cid in customer_ids)

        # Remaining demand, normalized by a fixed constant of 50
        state.extend(
            0.0 if cid in served else self.customer_demands.get(cid, 0) / 50.0
            for cid in customer_ids
        )

        # Global statistics. With no customers, treat the instance as fully
        # served instead of dividing by zero.
        completion = len(self.served_customers) / self.n_customers if self.n_customers else 1.0
        state.extend([
            completion,
            self.current_vehicle_idx / len(self.vehicles),
            self.total_distance / 1000.0,
        ])

        return state + [current_vehicle.position]


if __name__ == '__main__':
    # Smoke test: run a uniformly random policy on a small generated instance
    # and report the final state of every vehicle.
    n_customers = 10
    print(f"=== {n_customers}个客户CVRP测试 ===")
    # Generate a CVRP instance (depot + n_customers)
    distance_matrix, customer_demands, coordinates = generate_cvrp_instance(n_customers)

    print(f"节点配置: 1个depot + {n_customers}个客户")
    print(f"客户需求范围: {min(customer_demands.values())}-{max(customer_demands.values())}")
    print(f"总需求量: {sum(customer_demands.values())}")

    # Fleet configuration: capacity is 20% above the average demand per
    # vehicle, with a floor of 50.
    n_vehicles = 2
    vehicle_capacity = max(50, sum(customer_demands.values()) // n_vehicles * 1.2)

    print(f"智能车队配置: {n_vehicles}辆车，每辆容量{vehicle_capacity:.0f}")
    print(f"总容量: {n_vehicles * vehicle_capacity:.0f}")
    print()

    env = CVRPEnvironment(
        distance_matrix=distance_matrix,
        customer_demands=customer_demands,
        n_vehicles=n_vehicles,
        vehicle_capacity=int(vehicle_capacity),
        vehicle_speed=1.0
    )

    state = env.reset()
    total_reward = 0
    step_count = 0

    print("初始状态:")
    print()

    max_steps = min(200, n_customers * 3)  # step budget scales with the instance size

    while not env.done and step_count < max_steps:
        # random.choice raises IndexError on an empty list; stop instead of
        # crashing if the environment offers no action.
        if not env.available_actions:
            break
        action = random.choice(env.available_actions)
        state, reward, _, info = env.step(action)
        total_reward += reward
        step_count += 1
    print()
    print("\n=== 测试完成 ===")
    print(f"总步数: {step_count}")
    print(f"总奖励: {total_reward:.2f}")
    print(f"总距离: {env.total_distance:.1f}")
    print(f"环境完成: {env.done}")

    # Verify that every vehicle ended at the depot and is finished
    print("\n=== 最终车辆状态验证 ===")
    all_at_depot = True
    all_finished = True

    for i, vehicle in enumerate(env.vehicles):
        at_depot = vehicle.position == 0
        finished = vehicle.is_finished
        print(
            f"车辆 {i}: 位置={vehicle.position}, 载量={vehicle.current_load:.1f}/{vehicle.capacity}, 在depot={at_depot}, 已完成={finished}")
        print(f"  路径: {' -> '.join(map(str, vehicle.route))}")

        all_at_depot = all_at_depot and at_depot
        all_finished = all_finished and finished

    print("\n服务统计:")
    print(f"已服务客户: {len(env.served_customers)}/{env.n_customers}")
    print(f"剩余客户: {env.unserved_customers}")
    print(f"所有车辆都在depot: {all_at_depot}")
    print(f"所有车辆都已完成: {all_finished}")