import traceback
from random import random

import torch
import traci
from dijkstar import Graph, find_path
from dijkstar.algorithm import PathInfo
from app.network.Network import Network
from app.routing.RouterResult import RouterResult
import networkx as nx
import numpy as np
import sumolib

class CustomRouter(object):
    """Our own custom-defined router.

    All state lives on the class and every method is a ``classmethod``, so the
    router is used without instances and must be set up via ``init()`` before
    ``route`` / ``minimalRoute`` are called.
    """

    # Empty starting references, populated by init()
    edgeMap = None        # edge ID -> edge object from Network.routingEdges
    graph = None
    graph_node = None     # dijkstar graph: from-node -> to-node, with length/maxSpeed/lanes/edgeID
    graph_edge = None     # dijkstar graph: edge ID -> reachable edge ID
    nx_graph_edge = None  # networkx DiGraph over edge IDs, searched by route()
    dqn_model = {}        # car ID -> model exposing predict_value(obs, acts); class-level, shared
    dqn_edge_ids = None   # edge IDs; list position is used as DQN state index / action range
    net = None            # sumolib network object, used for coordinate lookups

    # the percentage of smart cars that should be used for exploration
    explorationPercentage = 0.0  # INITIAL JSON DEFINED!!!
    # randomizes the routes
    routeRandomSigma = 0.2  # INITIAL JSON DEFINED!!!
    # how much speed influences the routing
    maxSpeedAndLengthFactor = 1  # INITIAL JSON DEFINED!!!
    # multiplies the average edge value
    averageEdgeDurationFactor = 1  # INITIAL JSON DEFINED!!!
    # how important it is to get new data
    freshnessUpdateFactor = 10  # INITIAL JSON DEFINED!!!
    # defines what is the oldest value that is still a valid information
    freshnessCutOffValue = 500.0  # INITIAL JSON DEFINED!!!
    # how often we reroute cars
    reRouteEveryTicks = 1  # INITIAL JSON DEFINED!!!
    # weight (lambda) of the -max Q term in the RL heuristic
    qHeuristicWeight = 0.3

    @classmethod
    def init(cls, net_path="./app/map/ChongQing-114.net.xml"):
        """Set up the router from the already loaded ``Network``.

        Loads the SUMO network file (used for coordinate lookups in the
        heuristic) and builds the node graph and the edge graphs.

        :param net_path: path of the SUMO ``.net.xml`` file; defaults to the
            ChongQing map that used to be hard-coded here.
        """
        cls.net = sumolib.net.readNet(net_path)
        cls.graph_node = Graph()
        cls.graph_edge = Graph()
        cls.nx_graph_edge = nx.DiGraph()
        cls.edgeMap = {}

        # Node graph: every routing edge links its from-node to its to-node.
        for edge in Network.routingEdges:
            cls.edgeMap[edge.id] = edge
            cls.graph_node.add_edge(edge.fromNodeID, edge.toNodeID,
                                    {'length': edge.length, 'maxSpeed': edge.maxSpeed,
                                     'lanes': len(edge.lanes), 'edgeID': edge.id})

        # Edge graph: every edge links to the edges reachable from it. Built in
        # its own loop so it is constructed exactly once, and so it exists even
        # when there are no routing edges.
        for edge in Network.edges:
            edge_id = edge.getID()
            for outgoing_edge in edge.getOutgoing():
                out_id = outgoing_edge.getID()
                cls.graph_edge.add_edge(edge_id, out_id, {'from_edgeID': edge_id})
                # NOTE: despite its name, 'from_edgeID' here stores the edge being
                # entered; route() charges cost_data of that edge for the transition.
                cls.nx_graph_edge.add_edge(
                    edge_id,
                    out_id,
                    weight=1,  # placeholder; route() supplies its own weight function
                    from_edgeID=out_id,
                )

    @classmethod
    def minimalRoute(cls, fr, to, tick, car):
        """Create a minimal route on the node graph based on length / maxSpeed.

        :param fr: start node ID.
        :param to: destination node ID.
        :param tick: unused, kept for interface compatibility.
        :param car: unused, kept for interface compatibility.
        :returns: ``RouterResult(path_info, False)`` with dijkstar's path info.
        """
        cost_func = lambda u, v, e, prev_e: e['length'] / e['maxSpeed']
        route = find_path(cls.graph_node, fr, to, cost_func=cost_func)
        return RouterResult(route, False)

    @classmethod
    def route(cls, fr, to, tick, car, cost_data, is_RL):
        """Create a route from edge ``fr`` to edge ``to`` using A* on the edge graph.

        :param fr: ID of the start edge (a node of ``nx_graph_edge``).
        :param to: ID of the destination edge.
        :param tick: current simulation tick (unused here).
        :param car: vehicle object; ``car.id`` is used for TraCI/DQN lookups when ``is_RL``.
        :param cost_data: mapping edge ID -> cost of entering that edge.
        :param is_RL: when true, guide the search with the DQN-based heuristic;
            otherwise the heuristic is 0 (plain Dijkstra).
        :returns: ``RouterResult(list_of_edge_ids, isVictim)``.
        :raises networkx.NetworkXNoPath, networkx.NodeNotFound: propagated from ``nx.astar_path``.
        """
        # A random share of routes is flagged as "victims" (exploration).
        isVictim = cls.explorationPercentage > random()

        # The distance part of the heuristic depends only on the car's current
        # position and the destination, and both stay fixed during one search.
        # Compute it once here instead of doing a TraCI call per expanded node.
        h_base = cls._baseHeuristic(car, to) if is_RL else None

        route = nx.astar_path(
            cls.nx_graph_edge, fr, to,
            heuristic=lambda u, v: cls.heuristic(u, v, car, is_RL, h_base),
            weight=lambda u, v, d: cost_data[d['from_edgeID']],
        )
        return RouterResult(route, isVictim)

    @classmethod
    def heuristic(cls, u, v, car, is_RL, h_base=None):
        """Q-value-augmented A* heuristic (lower is better).

        h_new(u) = h_base + lambda * (-max_a Q(u, a)),  lambda = ``qHeuristicWeight``

        ``h_base`` is the Euclidean distance from the car's current position to
        the to-node of edge ``v``. It does not depend on ``u``, so within one
        search it shifts every estimate by the same amount.

        :param u: edge ID being evaluated.
        :param v: destination edge ID (networkx passes the search target).
        :param car: vehicle object; ``car.id`` keys TraCI and ``dqn_model``.
        :param is_RL: when false the heuristic is 0.
        :param h_base: precomputed base term; looked up via TraCI when None.
        :returns: heuristic estimate.
        """
        if not is_RL:
            return 0
        if h_base is None:
            h_base = cls._baseHeuristic(car, v)
        # Vehicles on edge u in the last simulation step; part of the DQN state.
        queue_length = traci.edge.getLastStepVehicleNumber(u)
        q_penalty = -cls._maxQValue(u, queue_length, car)
        return h_base + cls.qHeuristicWeight * q_penalty

    @classmethod
    def _baseHeuristic(cls, car, target_edge_id):
        """Return the straight-line distance from the car to the to-node of ``target_edge_id``.

        If the edge cannot be looked up, the error is printed and (0, 0) is
        used as the target position.
        """
        car_pos = traci.vehicle.getPosition(car.id)
        try:
            target_pos = cls.net.getEdge(target_edge_id).getToNode().getCoord()
        except Exception as e:
            print(f"获取路段{target_edge_id}的经纬度失败: {e}")
            target_pos = (0, 0)
        dx = car_pos[0] - target_pos[0]
        dy = car_pos[1] - target_pos[1]
        return (dx ** 2 + dy ** 2) ** 0.5

    @classmethod
    def _maxQValue(cls, edge_id, queue_length, car):
        """Return max over all actions of the car's DQN Q-value for state (edge_id, queue_length).

        Returns 0 when DQN data is missing or no model is registered for
        ``car.id``. Also returns 0 when prediction fails; in that case the
        error and traceback are printed.
        """
        models = cls.dqn_model
        edge_ids = cls.dqn_edge_ids
        # Cars without a registered model simply get no Q-guidance. Checking
        # up front avoids printing a KeyError traceback for every expanded node.
        if models is None or edge_ids is None or car.id not in models:
            return 0
        try:
            s_idx = edge_ids.index(edge_id)
            n_actions = len(edge_ids)
            # One row per candidate action: same state, action = edge index.
            obs = np.array([[s_idx, queue_length]], dtype=np.float32)
            obs_batch = np.repeat(obs, n_actions, axis=0)
            acts = np.arange(n_actions).reshape(-1, 1).astype(np.float32)
            q_values = models[car.id].predict_value(obs_batch, acts)
            return float(np.max(q_values))
        except Exception as e:
            print(f"Error in heuristic: {e}")
            traceback.print_exc()
            return 0  # fall back to no Q-guidance on error

    @classmethod
    def getFreshness(cls, edgeID, tick):
        """Return how fresh the duration data of ``edgeID`` is, in [0, 1].

        The value is 1 when the edge was updated at ``tick``. It falls linearly
        to 0 once the last update is ``freshnessCutOffValue`` ticks old or
        older. Returns 1 if the tick arithmetic raises TypeError (presumably
        when ``lastDurationUpdateTick`` is not set yet; confirm against the
        edge class).
        """
        try:
            # ticks elapsed since this edge's duration was last updated
            lastUpdate = float(tick) - cls.edgeMap[edgeID].lastDurationUpdateTick
            return 1 - min(1, max(0, lastUpdate / cls.freshnessCutOffValue))
        except TypeError:
            return 1

    @classmethod
    def getAverageEdgeDuration(cls, edgeID):
        """Return the average duration recorded for this edge, or 1 if it is unavailable."""
        try:
            return cls.edgeMap[edgeID].averageDuration
        except Exception:
            print("error in getAverageEdgeDuration")
            return 1

    @classmethod
    def applyEdgeDurationToAverage(cls, edge, duration, tick):
        """Forward a measured ``duration`` at ``tick`` to the edge's running average.

        Best-effort: returns None on success and 1 if the edge is unknown or
        the update fails.
        """
        try:
            cls.edgeMap[edge].applyEdgeDurationToAverage(duration, tick)
        except Exception:
            return 1
