import json
import math

import torch
import traci
import traci.constants as tc
import sumolib
import csv
import numpy as np
import datetime
import time
from collections import defaultdict

from app.streaming import RTXForword
from colorama import Fore
import traceback

from app import Config

from app.entitiy.CarRegistry import CarRegistry
from app.logdata import info
from app.routing.CustomRouter import CustomRouter
from app.streaming import RTXConnector
import time

# get the current system time
from app.routing.RoutingEdge import RoutingEdge
from app.entitiy.Car import Car

from app.datacollection.Datacollection import DataCollection
from app.network.Network import Network

from app.logdata import CSVLogger

current_milli_time = lambda: int(round(time.time() * 1000))


class VehicleExperienceCollector:
    """车辆经验数据收集器"""
    
    def __init__(self):
        # 存储每辆车的经验数据 {car_id: [(s_{t-1}, a_{t-1}, r, s_t), ...]}
        self.vehicle_experiences = defaultdict(list)
        # 存储每辆车的上一时刻状态
        self.previous_states = {}
        # 存储每辆车的上一时刻选择的将要进入的路段ID
        self.previous_actions = {}
        
    def collect_vehicle_experience(self, car_id, current_state, current_road, model_weights, edge_ids, car_id_map, tick):
        """
        收集单辆车的经验数据
        
        Args:
            car_id: 车辆ID
            current_state: 当前状态 (当前路段ID, 当前排队长度)
            current_road: 当前路段ID
            model_weights: 路段权重模型 [114, 1557]
            edge_ids: 路段ID列表
            car_id_map: 车辆ID映射字典
        """
        # 如果是第一次收集该车辆的数据
        if car_id not in self.previous_states:
            self.previous_states[car_id] = current_state
            return
        
        # 获取上一时刻的数据
        s_prev = self.previous_states[car_id]  # (上一时刻路段, 上一时刻排队长度)
        
        # 获取上一时刻选择的动作：上一时刻选择的将要进入的路段ID
        # 这里需要从车辆的路径规划中获取上一时刻选择的下一路段
        a_prev = self._get_previous_action(car_id, current_road)
        
        # 上一时刻的奖励：上一时刻选择的将要进入的路段的权重
        r_prev = 1.0  # 默认权重
        if a_prev in edge_ids and car_id in car_id_map:
            edge_index = edge_ids.index(a_prev)
            car_index = car_id_map[car_id]
            
            # 从 [114, 1557] 的张量中获取对应车辆和路段的权重
            if isinstance(model_weights, torch.Tensor):
                r_prev = model_weights[edge_index, car_index].item()
            else:
                r_prev = model_weights[edge_index][car_index] if edge_index < len(model_weights) and car_index < len(model_weights[edge_index]) else 1.0
        
        # 构建经验数据: (s_{t-1}, a_{t-1}, r, s_t)
        experience = {
            's_prev': s_prev,                    # 上一时刻状态: (上一时刻路段, 上一时刻排队长度)
            'a_prev': a_prev,                    # 上一时刻动作: 上一时刻选择的将要进入的路段ID
            'r_prev': r_prev,                    # 上一时刻奖励: 上一时刻选择的将要进入的路段的权重
            's_current': current_state,          # 当前状态: (当前路段, 当前排队长度)
            'timestamp': tick
        }
        # print(experience)
        # 如果上一时刻路段和当前路段相同，则不记录这条经验
        if s_prev[0] == current_state[0]:
            # 更新上一时刻数据
            self.previous_states[car_id] = current_state
            self.previous_actions[car_id] = self._get_current_next_road(car_id)
            return
        # 添加到该车辆的经验库
        self.vehicle_experiences[car_id].append(experience)
        
        # 更新上一时刻数据
        self.previous_states[car_id] = current_state
        # 存储当前选择的下一路段作为下一时刻的上一时刻动作
        self.previous_actions[car_id] = self._get_current_next_road(car_id)
    
    def _get_previous_action(self, car_id, current_road):
        """获取上一时刻选择的动作（将要进入的路段ID）"""
        # 如果车辆刚进入新路段，说明上一时刻选择的动作就是当前路段
        if car_id in self.previous_actions:
            return self.previous_actions[car_id]
        else:
            # 如果没有记录，使用当前路段作为上一时刻的动作
            return current_road
    
    def _get_current_next_road(self, car_id):
        """获取当前车辆选择的下一路段"""
        try:
            # 获取车辆的完整路径
            route = traci.vehicle.getRoute(car_id)
            current_road = traci.vehicle.getRoadID(car_id)
            
            # 找到当前路段在路径中的位置
            if current_road in route:
                current_index = route.index(current_road)
                # 获取下一路段
                if current_index + 1 < len(route):
                    return route[current_index + 1]
                else:
                    # 如果已经是最后一段，返回当前路段
                    return current_road
            else:
                return current_road
        except:
            # 如果获取失败，返回当前路段
            return traci.vehicle.getRoadID(car_id)
    
    def get_vehicle_experiences(self, car_id):
        """获取指定车辆的所有经验数据"""
        return self.vehicle_experiences.get(car_id, [])
    
    def get_all_experiences(self):
        """获取所有车辆的经验数据"""
        return dict(self.vehicle_experiences)
    
    def save_experiences_to_file(self, filename):
        """保存经验数据到文件"""
        import json
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self.get_all_experiences(), f, ensure_ascii=False, indent=2)
    
    def clear_experiences(self):
        """清空所有经验数据"""
        self.vehicle_experiences.clear()
        self.previous_states.clear()
        self.previous_actions.clear()


class Simulation(object):
    """ here we run the simulation in """

    new_episode = True  # 用于判断是否是新的第一轮，如果是新的一轮，需清空CarRegistry.applyCarCounter CarRegistry.cars类中的车辆
    tick = 0
    Network.loadNetwork()

    edge_ids = list(Network.edgeIds)
    
    # 添加经验收集器
    experience_collector = VehicleExperienceCollector()

    @classmethod
    def applyFileConfig(cls):
        """ reads configs from a json and applies it at realtime to the simulation """
        try:
            config = json.load(open('./knobs.json'))
            # "the percentage of routes we want to use for exploration"
            # 我们想要用于探索的路由百分比
            CustomRouter.explorationPercentage = config['explorationPercentage']

            # "how much the averageEdgeFactor influences the routing"
            # 路段edge因子对路由的影响有多大
            CustomRouter.averageEdgeDurationFactor = config['averageEdgeDurationFactor']
            # "how much the length/speed influences the routing"
            # 长度/速度对路由的影响程度
            CustomRouter.maxSpeedAndLengthFactor = config['maxSpeedAndLengthFactor']
            # "how much the freshnessUpdateFactor influences the routing"
            # freshnessUpdateFactor对路由的影响程度
            CustomRouter.freshnessUpdateFactor = config['freshnessUpdateFactor']
            # "if data is older than this we do not consider it in the algorithm"
            # 如果数据早于此，我们就不考虑它在算法中
            CustomRouter.freshnessCutOffValue = config['freshnessCutOffValue']
            # 在汽车启动后寻找新路线每x次，多久更新一次路段
            # "check for a new route very x times after the car starts"
            CustomRouter.reRouteEveryTicks = config['reRouteEveryTicks']
        except:
            pass

    @classmethod
    def start(cls):
        """ start the simulation """
        info("# Start adding initial cars to the simulation", Fore.MAGENTA)
        # apply the configuration from the json file    json文件加载
        cls.applyFileConfig()
        info(Fore.GREEN + "# apply the configuration from the json file OK !" +  Fore.RESET)
        # CarRegistry.applyCarCounter()
        info(Fore.GREEN + "# start loop !" + Fore.RESET)
        traci.simulation.subscribe((tc.VAR_ARRIVED_VEHICLES_IDS,))
        # 清空经验收集器
        cls.experience_collector.clear_experiences()

    def collect_vehicle_data(cls, model, vehicle_data, car_id_map):
        """
        收集每辆车的经验数据
        """
        # 获取所有车辆ID
        vehicle_ids = traci.vehicle.getIDList()
        
        for car_id in vehicle_ids:
            try:
                # 获取车辆当前路段
                current_road = traci.vehicle.getRoadID(car_id)
                
                # 跳过在交叉路口中间的车辆
                if ":" in current_road or current_road == '':
                    continue

                # 计算到末端的距离，跳过不在末端位置的车辆
                Car_pos = traci.vehicle.getPosition(car_id)  # 当前车辆经纬度
                recent_lane_id = traci.vehicle.getLaneID(car_id)  # 车辆所在车道
                laneEnd_pos = traci.lane.getShape(recent_lane_id)[-1]  # 车辆所在车道末端位置
                # # 计算车辆到所在车道末端的距离
                distance_Car_Lane = math.sqrt(
                    (Car_pos[0] - laneEnd_pos[0]) ** 2 + (Car_pos[1] - laneEnd_pos[1]) ** 2)
                if distance_Car_Lane > 20:
                    continue
                # 获取当前路段排队长度
                queue_length = traci.edge.getLastStepVehicleNumber(current_road)
                # 构建当前状态 s: (当前路段ID, 当前排队长度)
                current_state = (current_road, queue_length)
                # print(current_state)
                
                # 收集经验数据
                cls.experience_collector.collect_vehicle_experience(
                    car_id, current_state, current_road, model, cls.edge_ids, car_id_map, cls.tick
                )
                
            except Exception as e:
                traceback.print_exc()
                # 忽略已删除的车辆

    def loop(cls, model, car_id_map, is_RL):
        vehicle_data = dict()
        cost_data = dict()
        # sumotick = 600
        # while cls.tick <= sumotick:
        # Do one simulation step
        cls.tick += 1
        # cls.tick += 1
        # 前进一步
        traci.simulationStep()
        # print(f"===================={cls.tick}=========================")
        # if cls.tick % 10 == 0:
        #     print(f"===================={cls.tick}=========================")

        # # 获取当前tick车道信息
        if cls.tick == 1:
            vehicle_data = DataCollection.get_orig_data(cls.tick, cls.edge_ids)
        else:
            vehicle_data = DataCollection.get_orig_data(cls.tick, cls.edge_ids)
        # if tick < Config.sumotick:
        # 每秒计算一次全局的权重cost
        cost = model
        i = 0
        for key, value in vehicle_data.items():
            cost_data[key] = cost[i]
            vehicle_data[key] = vehicle_data[key] + [cost[i]]
            i = i + 1


        # 检查SUMO仿真中的车辆列表，并将未注册的车辆添加到车辆注册表中，并进行仿真
        # print(cls.new_episode)

        # 迪杰斯特拉车辆寻路
        cls.new_episode = CarRegistry.applyCarCounter(cls.tick - 1, cost_data, cls.new_episode, car_id_map,is_RL)
        # （当前排队长度，下一动作（路段id），奖励r（我们生成的W），下一时刻排队长度）
        # 检查是否有被删除的车辆，并将这些车辆重新添加到系统中
        # 目的是确保模拟中所有的车辆都能被正确处理，即使它们在某些情况下被意外地从模拟中删除了。
        # 通过重新添加这些车辆并将它们的到达时间设置为当前tick
        # print(traci.simulation.getSubscriptionResults()[122])
        for removedCarId in traci.simulation.getSubscriptionResults()[122]:
            # 它使用CarRegistry类的findById方法找到相应的Car对象，并将其到达时间设置为当前tick
            # print(CarRegistry.findById(removedCarId))
            # 可以确保程序不会因为缺少车辆对象而崩溃。在这种情况下，使用NullCar.setArrived对象来替代真实的车辆对象，以便在不影响其他部分的情况下继续进行处理。
            CarRegistry.findById(removedCarId).setArrived(cls.tick)

        # CarRegistry.processTick(cls.tick, cost_data, car_id_map)
        
        # 收集每辆车的经验数据
        cls.collect_vehicle_data(model, vehicle_data, car_id_map)

        # 实时配置更新
        if (cls.tick % 10) == 0:
            if Config.kafkaUpdates is False and Config.mqttUpdates is False:
                # json mode  json模式
                # print("使用json配置")
                cls.applyFileConfig()

        # print status update if we are not running in parallel mode
        # 如果不是在并行模式下运行，则打印状态更新
        # if (cls.tick % 1) == 0 and Config.parallelMode is False:
            # print(str(Config.processID) + " -> Step:" + str(cls.tick) + " # Driving cars: " + str(
            #     traci.vehicle.getIDCount()) + "/" + str(
            #     CarRegistry.totalCarCounter) + " # avgTripDuration: " + str(
            #     CarRegistry.totalTrips) + " # avgTripOverhead: " + str(
            #     CarRegistry.totalTripOverheadAverage))

        # 判断循环是否结束
        loopOver = True
        for key in CarRegistry.cars:
            if not CarRegistry.cars[key].disabled:
                loopOver = False
                break
        # 获取当前日期和时间
        # import datetime
        # current_datetime = datetime.datetime.now()
        # # 打印当前日期和时间
        # # 以特定格式打印日期和时间
        # formatted_datetime = current_datetime.strftime('%Y%m%d_%H%M%S')
        # with open('./app/data/tickdata_'+str(formatted_datetime)+'.csv', mode='w', newline='') as file:
        #     writer = csv.writer(file)
        #     writer.writerows(train_data)
        return vehicle_data
    
    @classmethod
    def get_experience_data(cls):
        """获取所有车辆的经验数据"""
        return cls.experience_collector.get_all_experiences()
    
    @classmethod
    def save_experience_data(cls, filename):
        """保存经验数据到文件"""
        cls.experience_collector.save_experiences_to_file(filename)
    
    @classmethod
    def get_vehicle_experience(cls, car_id):
        """获取指定车辆的经验数据"""
        return cls.experience_collector.get_vehicle_experiences(car_id)


