# 导入所需的库
import sys
import traceback
from datetime import datetime

import d3rlpy
import ray
from d3rlpy.algos import DQNConfig
from d3rlpy.metrics import TDErrorEvaluator
from torch_geometric.nn import GCNConv, TopKPooling

from app import Boot
from app.routing.CustomRouter import CustomRouter
from app.simulation.Simulation import Simulation
import csv
from app.network.Network import Network
import sumolib
from torch_geometric.data import Data
import traci
from app.datacollection.Datacollection import DataCollection
import os  # 用于文件和路径操作
import torch.nn.functional as F

# 导入深度学习相关的库
import torch  # PyTorch深度学习框架
import torch.nn as nn  # 神经网络模块
import torch.optim as optim  # 优化器
import pandas as pd
from tqdm import tqdm
import pickle
import numpy as np

from data_parallel import BalancedDataParallel

# Set an environment variable to work around a library conflict on macOS.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


class CustomDDPMScheduler:
    """Minimal DDPM noise scheduler: forward noising and a single reverse step.

    All schedule tensors are 1-D, indexed by timestep, and are stored on
    ``self.device`` (CUDA when available, otherwise CPU).
    """

    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule="linear"):
        """
        Precompute the noise schedule.

        Args:
            num_train_timesteps: total number of diffusion steps T.
            beta_start: first value of the beta sequence.
            beta_end: last value of the beta sequence.
            beta_schedule: how the beta sequence is spaced; only "linear" is supported.

        Raises:
            NotImplementedError: if ``beta_schedule`` is anything other than "linear".
        """
        self.num_train_timesteps = num_train_timesteps
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # beta sequence
        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps).to(self.device)
        else:
            raise NotImplementedError(f"{beta_schedule} schedule is not implemented")

        # alpha_t = 1 - beta_t
        self.alphas = 1. - self.betas

        # alpha_bar_t = cumulative product of alpha up to and including t
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # alpha_bar_{t-1}, with 1.0 prepended for the first step
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t), used by the forward process
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    def add_noise(self, x_0, noise, timesteps):
        """
        Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x_0: clean samples; the leading dimension is the batch.
            noise: noise tensor with the same shape as ``x_0``.
            timesteps: integer tensor of timestep indices, one per batch element.

        Returns:
            The noised samples, on ``self.device``.
        """
        timesteps = timesteps.to(self.device)
        x_0 = x_0.to(self.device)
        noise = noise.to(self.device)

        # Reshape the per-sample coefficients to (B, 1, ..., 1) so they broadcast
        # over x_0 whatever its rank is (identical to view(-1, 1, 1) for rank-3 input).
        coeff_shape = (-1,) + (1,) * (x_0.dim() - 1)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[timesteps].view(coeff_shape)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[timesteps].view(coeff_shape)

        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

        return x_t

    def step(self, model_output, timestep, sample):
        """
        Reverse process: draw x_{t-1} ~ N(mu_t, sigma_t^2 I), where
        mu_t = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
        and sigma_t^2 is the posterior variance. At t == 0 the mean is returned
        without added noise.

        Args:
            model_output: the noise eps_theta predicted by the model.
            timestep: current timestep t (a single index shared by the whole batch).
            sample: current sample x_t.

        Returns:
            An object whose ``prev_sample`` attribute holds x_{t-1}.
        """
        t = timestep

        model_output = model_output.to(self.device)
        sample = sample.to(self.device)

        alpha_t = self.alphas[t]
        alpha_prod_t = self.alphas_cumprod[t]
        beta_t = self.betas[t]

        # Posterior mean mu_t.
        pred_mean = (1 / torch.sqrt(alpha_t)) * (
            sample - (beta_t / torch.sqrt(1 - alpha_prod_t)) * model_output
        )

        if t > 0:
            noise = torch.randn_like(model_output)
            # The mean is used as-is; it must not be rescaled by sqrt(alpha_bar_{t-1}),
            # otherwise the sample shrinks at every reverse step.
            pred_prev_sample = pred_mean + torch.sqrt(self.posterior_variance[t]) * noise
        else:
            pred_prev_sample = pred_mean

        return type('StepOutput', (), {'prev_sample': pred_prev_sample})()


class TrafficEnvironment:
    """
    Traffic simulation environment driven through ``Boot``, ``traci`` and a
    simulation instance supplied by the caller.

    Layout of the per-edge feature vector (indices into each value of the
    ``vehicle_data`` mapping returned by the simulation):
    0 time, 1 has sidewalk, 2 has bicycle lane, 3 has vehicle lane,
    4 number of lanes, 5 number of vehicles on the edge (congestion),
    6 mean vehicle speed, 7 number of halted vehicles, 8 edge length,
    9 current travel time, 10 lane occupancy at the previous time step, 11 weight.
    """

    def __init__(self, weight_tick, end_tick):
        """
        Args:
            weight_tick: offset subtracted from the step counter when selecting
                a row of ``weight_data`` in ``weight_step``.
            end_tick: episode length in steps; ``game_step`` reports ``done`` once
                the counter reaches it.
        """
        self.Edges_num = 10
        self.end_tick = end_tick
        self.weight_tick = weight_tick
        self.count = 0  # step counter
        self.train_data = pd.read_csv('vehicle_counts_test.csv')  # observed congestion Q
        try:
            self.weight_data = pd.read_csv('weight_data.csv', header=None)
        except (OSError, ValueError):
            # The weight file is optional (missing, empty or unparsable): record
            # that explicitly so weight_step can fail with a clear message.
            self.weight_data = None

    def reset(self):
        """Start the SUMO application via ``Boot.start`` (process 0, serial mode, no GUI).

        Returns nothing; the first observation is obtained from one of the
        ``*_step`` methods.
        """
        process_id = 0
        parallel_mode = False
        use_gui = False
        Boot.start(process_id, parallel_mode, use_gui)

    def init_step(self, simulation_instance, edge_index):
        """Advance the simulation by one step and return feature 7 of every edge.

        Args:
            simulation_instance: object exposing ``tick`` and ``edge_ids``.
            edge_index: unused here; kept for a signature parallel to the other steps.

        Returns:
            1-D float32 tensor (``requires_grad=True``) with one entry per edge.
        """
        traci.simulationStep()
        vehicle_data = DataCollection.get_orig_data(simulation_instance.tick, simulation_instance.edge_ids)
        Q_init = torch.tensor(
            [values[7] for values in vehicle_data.values()],
            dtype=torch.float32,
            requires_grad=True,
        )
        self.count += 1
        return Q_init

    def weight_step(self, simulation_instance, edge_index):
        """Run one simulation step using a stored row of weights as the action.

        Args:
            simulation_instance: object whose ``loop(action)`` returns the per-edge data mapping.
            edge_index: graph connectivity stored on the returned ``Data`` object.

        Returns:
            ``Data`` whose ``x`` holds features 4..10 of every edge.

        Raises:
            RuntimeError: if ``weight_data.csv`` could not be loaded at construction.
        """
        if self.weight_data is None:
            raise RuntimeError("weight_data.csv could not be loaded; weight_step is unavailable")
        print("0 -> weight")
        action = torch.tensor(self.weight_data.iloc[(self.count - self.weight_tick), :].values,
                              dtype=torch.float32,
                              requires_grad=True)
        print(action)
        vehicle_data = simulation_instance.loop(action)
        state_init = torch.tensor(
            [values[4:11] for values in vehicle_data.values()],
            dtype=torch.float32,
            requires_grad=True,
        )
        graph_state_init = Data(x=state_init, edge_index=edge_index)

        self.count += 1
        return graph_state_init

    def game_step(self, simulation_instance, edge_index, action, car_id_map, is_RL):
        """
        Apply ``action`` for one simulation step and collect the resulting state.

        Args:
            simulation_instance: object whose ``loop(action, car_id_map, is_RL)``
                returns the per-edge data mapping.
            edge_index: graph connectivity stored on the returned ``Data`` object.
            action: weights handed to the simulation.
            car_id_map: passed through to the simulation.
            is_RL: passed through to the simulation.

        Returns:
            next_graph_state: ``Data`` whose ``x`` holds features 4..10 of every edge.
            done: True when the episode has reached ``end_tick`` (the counter is then reset to 0).
            Q_prime_t_plus_1: simulated feature 7 of every edge.
            actual_congestion: column 2 of ``train_data`` for the ``Edges_num`` rows
                belonging to the current step (read before the counter advances).
        """
        first_row = self.Edges_num * self.count
        actual_congestion = torch.tensor(
            self.train_data.iloc[first_row:first_row + self.Edges_num, 2].values,
            dtype=torch.float32,
            requires_grad=True)

        vehicle_data = simulation_instance.loop(action, car_id_map, is_RL)
        edge_features = list(vehicle_data.values())
        next_state = torch.tensor(
            [values[4:11] for values in edge_features],
            dtype=torch.float32,
            requires_grad=True,
        )
        next_graph_state = Data(x=next_state, edge_index=edge_index)
        Q_prime_t_plus_1 = torch.tensor(
            [values[7] for values in edge_features],
            dtype=torch.float32,
            requires_grad=True,
        )

        done = False
        self.count += 1
        if self.count + 1 > self.end_tick:
            self.count = 0
            done = True

        return next_graph_state, done, Q_prime_t_plus_1, actual_congestion

    @staticmethod
    def finish():
        """End the current run by calling ``Boot.end()``."""
        Boot.end()


class GraphEncoder(nn.Module):
    """Two-stage graph encoder: pool nodes inside each graph, then pool across graphs.

    The size comments below use the example figures from the original author
    (114 road nodes -> 8, 1557 car graphs -> 30); the real sizes come from the
    constructor arguments.
    """

    def __init__(self, num_road, num_car, latent_dim_road, latent_dim_car, input_dim=1, hidden_dim=32, output_dim=1):
        """
        Args:
            num_road: number of nodes per graph.
            num_car: number of graphs per sample.
            latent_dim_road: number of nodes kept per graph after the first pooling.
            latent_dim_car: number of graphs kept after the second pooling.
            input_dim: per-node input feature size (forward() adds a single
                trailing feature dimension, so it effectively expects 1).
            hidden_dim: width of the first graph convolution; the second uses twice that.
            output_dim: output feature size per node (the final view in forward()
                only works with 1).
        """
        super().__init__()
        # Stage 1: node features inside each graph (e.g. 114 nodes -> 8 nodes).
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.pool1 = TopKPooling(hidden_dim, ratio=latent_dim_road / num_road)

        # Stage 2: relations between graphs (e.g. 1557 graphs -> 30 graphs).
        self.conv2 = GCNConv(hidden_dim, hidden_dim * 2)
        self.pool2 = TopKPooling(hidden_dim * 2, ratio=latent_dim_car / num_car)

        # Final feature projection.
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.num_car = num_car
        self.num_road = num_road
        self.latent_dim_road = latent_dim_road
        self.latent_dim_car = latent_dim_car
        # Cache the fully connected graph-to-graph edges once.
        self.register_buffer('cached_graph_edges', self.create_graph_edges(num_car))

    def forward(self, x, edge_index):
        """
        Args:
            x: [batch_size, num_nodes, num_graphs]
            edge_index: connectivity of a single graph, [2, num_edges]

        Returns:
            Tensor of shape [batch_size, latent_dim_car, latent_dim_road].

        Raises:
            ValueError: if the batched edge index refers to a node beyond the
                flattened node count.
        """
        x = x.transpose(1, 2)
        x = x.unsqueeze(-1)
        batch_size, num_graphs, num_nodes, features = x.size()

        # === Stage 1: structure inside each graph ===
        # [batch, graphs, nodes, features] -> [batch*graphs*nodes, features]
        x = x.reshape(-1, features)

        # Replicate edge_index once per (batch, graph) copy, shifting node ids by
        # num_nodes per copy. Vectorised equivalent of cloning in a Python loop.
        num_copies = batch_size * num_graphs
        offsets = torch.arange(num_copies, device=edge_index.device).repeat_interleave(
            edge_index.size(1)) * num_nodes
        batch_edge_index = (edge_index.repeat(1, num_copies) + offsets).to(x.device)

        # Sanity check on the index range.
        max_index = batch_edge_index.max().item()
        expected_nodes = batch_size * num_graphs * num_nodes
        if max_index >= expected_nodes:
            raise ValueError(f"Edge index {max_index} exceeds total nodes {expected_nodes}")

        # Graph convolution 1 -> [batch*graphs*nodes, hidden_dim]
        x = F.relu(self.conv1(x, batch_edge_index))

        # Pool every graph down to latent_dim_road nodes.
        x, _, _, _, _, _ = self.pool1(
            x, batch_edge_index,
            batch=torch.arange(batch_size * num_graphs, device=x.device).repeat_interleave(num_nodes)
        )

        # === Regroup to handle relations between graphs ===
        # NOTE(review): this view assumes pool1 keeps exactly latent_dim_road nodes
        # per graph — confirm against TopKPooling's rounding of the ratio.
        x = x.view(batch_size, num_graphs, self.latent_dim_road, -1)

        # === Stage 2: relations between graphs ===
        graph_edge_index = self.cached_graph_edges.to(x.device)

        # Shift the cached edges once per batch element.
        batch_graph_edges = torch.cat(
            [graph_edge_index + i * num_graphs for i in range(batch_size)], dim=1
        )

        # [batch, graphs, latent_road, hidden_dim] -> [batch*graphs*latent_road, hidden_dim]
        x = x.view(-1, x.size(-1))

        # NOTE(review): batch_graph_edges indexes batch*num_graphs ids while x has
        # batch*num_graphs*latent_dim_road rows — confirm this pairing is intended.
        x = F.relu(self.conv2(x, batch_graph_edges))

        # Pool down across graphs.
        x, _, _, _, _, _ = self.pool2(
            x, batch_graph_edges,
            batch=torch.arange(batch_size, device=x.device).repeat_interleave(num_graphs * self.latent_dim_road)
        )
        expected_nodes = batch_size * self.latent_dim_car * self.latent_dim_road
        if x.size(0) != expected_nodes:
            # Trim surplus rows so the final view below succeeds.
            x = x[:expected_nodes]

        # === Final projection and reshape ===
        x = self.fc(x)
        x = x.view(batch_size, self.latent_dim_car, self.latent_dim_road)
        return x

    @staticmethod
    def create_graph_edges(num_graphs):
        """Return all ordered pairs (i, j), i != j, over ``num_graphs`` ids.

        The result is a long tensor of shape [2, num_graphs * (num_graphs - 1)],
        ordered by i then j. For ``num_graphs`` <= 1 it is an empty [2, 0] tensor.
        """
        ids = torch.arange(num_graphs, dtype=torch.long)
        src = ids.repeat_interleave(num_graphs)
        dst = ids.repeat(num_graphs)
        keep = src != dst
        return torch.stack([src[keep], dst[keep]], dim=0).contiguous()


class GraphDecoder(nn.Module):
    """Two-stage graph decoder: expand the number of graphs, then the nodes per graph.

    The size comments below use the example figures from the original author
    (30 graphs -> 1557, 8 nodes -> 114); the real sizes come from the
    constructor arguments.
    """

    def __init__(self, num_road, num_car, latent_dim_road, latent_dim_car, input_dim=1, hidden_dim=32, output_dim=1):
        """
        Args:
            num_road: number of nodes per graph in the output.
            num_car: number of graphs per sample in the output.
            latent_dim_road: number of nodes per graph in the latent input.
            latent_dim_car: number of graphs in the latent input.
            input_dim: input feature size of the first graph convolution
                (forward() feeds it single-feature rows, so it effectively expects 1).
            hidden_dim: width of the second graph convolution; the first uses twice that.
            output_dim: output feature size per node (the final view in forward()
                only works with 1).
        """
        super().__init__()
        # Stage 1: relations between graphs (e.g. 30 graphs -> 1557 graphs).
        self.graph_expand = nn.Linear(latent_dim_car, num_car)
        self.deconv1 = GCNConv(input_dim, hidden_dim * 2)

        # Stage 2: node expansion (e.g. 8 nodes -> 114 nodes).
        self.node_expand = nn.Linear(latent_dim_road, num_road)
        self.deconv2 = GCNConv(hidden_dim * 2, hidden_dim)

        # Output head.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )
        self.num_car = num_car
        self.num_road = num_road
        self.latent_dim_road = latent_dim_road
        self.latent_dim_car = latent_dim_car
        # Cache the fully connected graph-to-graph edges once.
        self.register_buffer('cached_graph_edges', self.create_graph_edges(num_car))

    def forward(self, x, edge_index):
        """
        Args:
            x: latent tensor, [batch_size, latent_dim_car, latent_dim_road]
            edge_index: connectivity of a single graph, [2, num_edges]

        Returns:
            Tensor of shape [batch_size, num_road, num_car].
        """
        batch_size = x.size(0)

        # === Add a feature dimension === [batch, latent_car, latent_road, 1]
        x = x.unsqueeze(-1)

        # === Stage 1: expand across graphs ===
        # [batch, latent_car, latent_road, 1] -> [batch, 1, latent_road, latent_car]
        x = x.transpose(1, 3)

        # Flatten to [batch*latent_road, latent_car] for the linear expansion.
        x = x.reshape(-1, self.latent_dim_car)

        # latent_car -> num_car
        x = self.graph_expand(x)

        # [batch*latent_road, num_car] -> [batch, num_car, latent_road, 1]
        x = x.view(batch_size, 1, self.latent_dim_road, self.num_car).transpose(1, 3)

        # Flatten to single-feature rows for the graph convolution.
        x = x.reshape(-1, 1)

        # Fully connected graph relations, shifted once per batch element.
        graph_edge_index = self.cached_graph_edges.to(x.device)
        batch_edges = torch.cat(
            [graph_edge_index + i * self.num_car for i in range(batch_size)], dim=1
        )

        # NOTE(review): batch_edges indexes batch*num_car ids while x has
        # batch*num_car*latent_dim_road rows — confirm this pairing is intended.
        x = F.relu(self.deconv1(x, batch_edges))  # [batch*num_car*latent_road, hidden_dim*2]

        # === Stage 2: expand nodes ===
        x = x.view(batch_size, self.num_car, self.latent_dim_road, -1)

        # latent_road -> num_road, applied along the node axis.
        x = x.transpose(2, 3)
        x = self.node_expand(x)
        x = x.transpose(2, 3)  # [batch, num_car, num_road, hidden_dim*2]
        # Use the actual feature width rather than a literal 64, so hidden_dim != 32 works.
        x = x.reshape(-1, x.size(-1))

        # Replicate edge_index once per (batch, graph) copy, shifting node ids.
        node_edge_index = edge_index.repeat(1, batch_size * self.num_car) + \
                          torch.arange(batch_size * self.num_car, device=x.device).repeat_interleave(
                              edge_index.size(1)) * self.num_road

        x = F.relu(self.deconv2(x, node_edge_index))  # [batch*num_car*num_road, hidden_dim]

        # === Output ===
        x = self.fc(x)
        x = x.view(batch_size, self.num_car, self.num_road)
        return x.transpose(1, 2)  # [batch, num_road, num_car]

    @staticmethod
    def create_graph_edges(num_graphs):
        """Return all ordered pairs (i, j), i != j, over ``num_graphs`` ids.

        The result is a long tensor of shape [2, num_graphs * (num_graphs - 1)],
        ordered by i then j. For ``num_graphs`` <= 1 it is an empty [2, 0] tensor.
        """
        ids = torch.arange(num_graphs, dtype=torch.long)
        src = ids.repeat_interleave(num_graphs)
        dst = ids.repeat(num_graphs)
        keep = src != dst
        return torch.stack([src[keep], dst[keep]], dim=0).contiguous()


# 定义图自编码器模型
class Autoencoder(nn.Module):
    """Graph autoencoder: a GraphEncoder and a GraphDecoder sharing one edge index."""

    def __init__(self, edge_index, num_road, num_car, latent_dim_road, latent_dim_car):
        super(Autoencoder, self).__init__()
        dims = (num_road, num_car, latent_dim_road, latent_dim_car)
        self.encoder = GraphEncoder(*dims)
        self.decoder = GraphDecoder(*dims)
        # Registered as a persistent buffer holding a detached copy of edge_index.
        self.register_buffer('edge_index', edge_index.detach().clone())

    def encode(self, x):
        """Map ``x`` to its latent representation."""
        return self.encoder(x, self.edge_index)

    def decode(self, z):
        """Map a latent ``z`` back to the input space."""
        return self.decoder(z, self.edge_index)

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for ``x``."""
        latent = self.encode(x)
        return latent, self.decode(latent)


# 定义Y自编码器模型
class YAutoencoder(nn.Module):
    """Fully connected autoencoder: num_road -> 1024 -> 128 -> latent_dim_road and back."""

    def __init__(self, num_road, latent_dim_road):
        super().__init__()
        widths = (num_road, 1024, 128, latent_dim_road)
        # Encoder narrows down to the latent size; the decoder mirrors it.
        self.encoder = self._mlp(widths)
        self.decoder = self._mlp(widths[::-1])

    @staticmethod
    def _mlp(widths):
        """Stack Linear layers over consecutive widths, with ReLU between (not after) them."""
        layers = []
        last = len(widths) - 2
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if position < last:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return the latent code for ``x``."""
        return self.encoder(x)

    def decode(self, x):
        """Return the reconstruction for the latent code ``x``."""
        return self.decoder(x)

    def forward(self, x):
        """Return ``(encoded, decoded)`` for ``x``."""
        encoded = self.encode(x)
        return encoded, self.decode(encoded)


class DiffusionModel(nn.Module):
    """MLP noise predictor conditioned on a pooled condition tensor and the timestep."""

    def __init__(self, latent_dim_car, condition_dim, hidden_dim, num_road=114, latent_dim_road=8):
        """
        Args:
            latent_dim_car: size of the latent axis the MLP operates on.
            condition_dim: number of condition features appended per row. forward()
                appends exactly one, so this must be 1 for the shapes to line up.
            hidden_dim: accepted for interface compatibility; the layer widths
                below are fixed and do not use it.
            num_road: size of the road axis of the condition tensor (default 114).
            latent_dim_road: size the condition is projected to; must equal the
                last dimension of ``x`` in forward() (default 8).
        """
        super(DiffusionModel, self).__init__()
        # The extra +1 input feature is the timestep.
        self.fc1 = nn.Linear(latent_dim_car + condition_dim + 1, 256)
        self.fc2 = nn.Linear(256, 1280)
        self.fc3 = nn.Linear(1280, 256)
        self.fc4 = nn.Linear(256, latent_dim_car)

        # Global average pooling: [N, num_road, L] -> [N, num_road, 1]
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        # Projection along the road axis: num_road -> latent_dim_road
        self.linear = nn.Linear(num_road, latent_dim_road)

    def forward(self, x, condition, timesteps):
        """
        Args:
            x: noisy latent, [N, latent_dim_car, latent_dim_road]
            condition: [N, L, num_road]; averaged over L, then projected to latent_dim_road
            timesteps: [N] timestep indices

        Returns:
            Predicted noise, same shape as ``x``.
        """
        x = x.transpose(1, 2)  # [N, latent_road, latent_car]
        condition = condition.transpose(1, 2)  # [N, num_road, L]
        condition = self.global_pool(condition)  # [N, num_road, 1]
        condition = condition.permute(0, 2, 1)  # [N, 1, num_road]
        condition = self.linear(condition)  # [N, 1, latent_road]
        condition = condition.permute(0, 2, 1)  # [N, latent_road, 1]
        # Broadcast the timestep to one value per latent_road row: [N, latent_road, 1]
        timesteps = timesteps.unsqueeze(-1).unsqueeze(-1).repeat(1, condition.shape[1], 1)

        x = torch.cat([x, condition, timesteps], dim=2)  # [N, latent_road, latent_car + 2]
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x.transpose(1, 2)


# 训练自编码器的函数
def train_autoencoder(autoencoder, dataset_all, num_car, device, batch_size, num_epochs):
    """
    Train the graph autoencoder on the first ``num_car`` columns of ``dataset_all``.

    Args:
        autoencoder: model whose forward returns ``(latent, reconstruction)``.
        dataset_all: tensor whose last axis starts with the ``num_car`` input columns.
        num_car: number of leading columns (last axis) used as the training input.
        device: device each batch is moved to.
        batch_size: mini-batch size.
        num_epochs: number of passes over the data.

    Returns:
        The trained ``autoencoder``. Its state dict is also written to
        'graph_autoencoder1401.pth'.
    """
    # Shuffled mini-batches over the float inputs.
    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(dataset_all[:, :, :num_car].float()),
        batch_size=batch_size,
        shuffle=True,
    )

    criterion = nn.MSELoss()
    optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        total_loss = 0
        for (inputs,) in dataloader:
            inputs = inputs.to(device)

            optimizer.zero_grad()
            reconstruction = autoencoder(inputs)[1]
            loss = criterion(reconstruction, inputs)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Report the mean batch loss every fifth epoch.
        if (epoch + 1) % 5 == 0:
            print(f"Graph Autoencoder Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(dataloader):.4f}")

    # Persist the final weights.
    torch.save(autoencoder.state_dict(), 'graph_autoencoder1401.pth')
    return autoencoder


def train_Yautoencoder(Yautoencoder, dataset_all, num_car, device, batch_size, num_epochs):
    """
    Train the Y autoencoder on column ``num_car`` (last axis) of ``dataset_all``.

    Args:
        Yautoencoder: model whose forward returns ``(encoded, decoded)``.
        dataset_all: tensor whose last axis holds the target column at index ``num_car``.
        num_car: index of the target column on the last axis.
        device: device each batch is moved to.
        batch_size: mini-batch size.
        num_epochs: number of passes over the data.

    Returns:
        The trained ``Yautoencoder``. Its state dict is also written to
        'Yautoencoder_model1401.pth'.
    """
    y = dataset_all[:, :, num_car].float()
    dataset = torch.utils.data.TensorDataset(y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = optim.Adam(Yautoencoder.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0

        for batch_y, in dataloader:
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            _, decoded = Yautoencoder(batch_y)

            loss = criterion(decoded, batch_y)

            # Each batch builds a fresh graph and is backpropagated once, so the
            # graph does not need to be retained (retaining it only holds memory).
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Report the mean batch loss every fifth epoch.
        if (epoch + 1) % 5 == 0:
            print(f"YAutoencoder Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(dataloader):.4f}")

    # Persist the final weights.
    torch.save(Yautoencoder.state_dict(), 'Yautoencoder_model1401.pth')
    return Yautoencoder


# 修改训练扩散模型的函数
def train_diffusion_model(model, dataloader, optimizer, device, num_epochs):
    """
    Train the noise-prediction model with a per-sample weighted MSE objective.

    Args:
        model: called as ``model(noisy_x, batch_s, timesteps)`` and expected to
            return a tensor shaped like the noise.
        dataloader: yields ``(batch_x, batch_c, batch_s, mse_value)``; ``batch_c``
            is not used.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches and the model inputs are placed on.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        The trained ``model``. Its state dict is also written to 'DiffusionModel1401.pth'.
    """
    losses = []

    noise_scheduler = CustomDDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear"
    )

    model.train()
    for epoch in range(num_epochs):
        epoch_losses = []
        for batch_x, batch_c, batch_s, mse_value in dataloader:
            # Work on detached copies so the loader's tensors are never modified.
            # The inputs are data, not parameters, so they need no gradient.
            batch_x = batch_x.clone().detach().to(device)
            batch_s = batch_s.clone().detach().to(device)
            mse_value = mse_value.clone().detach().squeeze(-1).to(device)

            # Draw one timestep per sample from the scheduler's own range.
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_x.shape[0],), device=device)
            noise = torch.randn_like(batch_x)

            # Forward-noise the batch. The scheduler works on its own device, so
            # bring the result back to the training device.
            noisy_x = noise_scheduler.add_noise(batch_x, noise, timesteps).to(device)

            # Predict the noise.
            pred_noise = model(noisy_x, batch_s, timesteps)

            # Per-sample weight, shaped [N, 1, 1] to broadcast over the noise.
            mse_value = mse_value.unsqueeze(-1).unsqueeze(-1)
            loss = F.mse_loss(pred_noise * mse_value, noise * mse_value)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

        avg_loss = sum(epoch_losses) / len(epoch_losses)
        losses.append(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Diffusion Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Persist the final weights.
    torch.save(model.state_dict(), 'DiffusionModel1401.pth')
    return model


def generate_new_samples(model, autoencoder, Yautoencoder, latent_dim_road, latent_dim_car, new_condition, num_samples,
                         device):
    """
    Sample latents with the reverse diffusion process and decode them.

    Args:
        model: noise predictor called as ``model(x, new_condition, timesteps)``.
        autoencoder: provides ``decode`` for mapping latents back to the original space.
        Yautoencoder: only switched to eval mode here.
        latent_dim_road: size of the last latent axis.
        latent_dim_car: size of the middle latent axis.
        new_condition: condition tensor handed to ``model`` at every step.
        num_samples: number of samples to draw.
        device: device the latents and timesteps are placed on.

    Returns:
        The decoded samples.
    """
    scheduler = CustomDDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear"
    )

    for module in (model, autoencoder, Yautoencoder):
        module.eval()

    # Start from Gaussian noise in the latent space.
    latents = torch.randn(num_samples, latent_dim_car, latent_dim_road).to(device)
    with torch.no_grad():
        # Walk the timesteps from 999 down to 0.
        for t in reversed(range(1000)):
            step_ids = torch.full((num_samples,), t, device=device)
            predicted = model(latents, new_condition, step_ids)
            latents = scheduler.step(predicted, t, latents).prev_sample

    # Decode the final latents back to the original dimensions.
    return autoencoder.decode(latents)

# Start Ray when this module is imported (module-level side effect);
# train_dqn further down is declared as a Ray remote task.
ray.init()

# 生成初始数据集的函数
def generate_dataset(edge_index, car_id_map, new_condition, device, num_road, simu_seconds, num_car, num_samples):
    """
    Load or build the full dataset, then load or build the filtered subset ``datasetD``.

    ``dataset_all`` is laid out along its last axis as
    ``[num_car weight columns | 1 congestion column | 900 S columns]``.
    If './data/dataset_all_queue.pickle' exists it is loaded; otherwise
    ``num_samples`` samples are produced, each by a random-weight simulation,
    per-car DQN training via Ray, and a second simulation with ``is_RL=True``.
    ``datasetD`` keeps the samples whose congestion differs from
    ``new_condition`` by a mean squared error above 30, and is cached in
    './data/datasetD.pickle'.

    Args:
        edge_index: [2, num_edges] connectivity between road edges.
        car_id_map: mapping from vehicle id to column index.
        new_condition: target congestion compared against each sample's congestion.
        device: device the tensors are placed on.
        num_road: number of road edges.
        simu_seconds: episode length passed to ``TrafficEnvironment``.
        num_car: number of vehicles (weight columns).
        num_samples: number of samples to simulate when no cached dataset exists.

    Returns:
        ``(datasetD, dataset_all)`` on ``device`` — but see the note on
        ``sys.exit(0)`` at the end: as written the process exits before returning.
    """
    if os.path.exists("./data/dataset_all_queue.pickle"):
        # NOTE: pickle.load can execute arbitrary code; only load files this
        # script produced itself.
        with open('./data/dataset_all_queue.pickle', 'rb') as f:
            dataset_all = pickle.load(f)
        # Split in the same order the tensor is assembled below:
        # weights (num_car), congestion (1), S (900).
        xts, cons, s = torch.split(dataset_all, [num_car, 1, 900], dim=2)
        x = xts.to(device)
        y = cons.squeeze(-1).to(device)
        s = s.transpose(1, 2).to(device)
    else:
        xts = torch.empty(0, num_road, num_car).round(decimals=4).to(device)  # [0, num_road, num_car]
        with open('./data/xts.pickle', 'wb') as f:
            pickle.dump(xts, f)
        cons = torch.empty((0, num_road)).to(device)  # [0, num_road]
        S = torch.empty((0, 900, num_road)).to(device)
        for i in range(num_samples):
            # Fresh environment and simulation for every sample.
            env = TrafficEnvironment(0, simu_seconds)
            simulation_instance = Simulation()
            # Start SUMO.
            env.reset()
            done = False
            weight = None
            # Track which vehicles have already been seen.
            prev_vehicle_ids = set()
            all_experiences = {}  # {car_id: [experience, ...], ...}
            edge_ids = simulation_instance.edge_ids  # list of road-edge ids
            S_np = None
            car_n = 0
            while not done:
                # Random weights, drawn once per episode.
                if weight is None:
                    weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
                next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight, car_id_map, is_RL=False)
                # Vehicles currently in the simulation, and those that just appeared.
                curr_vehicle_ids = set(traci.vehicle.getIDList())
                new_vehicle_ids = curr_vehicle_ids - prev_vehicle_ids
                # The state snapshot is taken from the first step only.
                if S_np is None:
                    S_np = result.cpu().detach().numpy()
                # Record experiences for every newly appeared vehicle (at most 900).
                for car_id in new_vehicle_ids:
                    print(f"new_car_id:{car_id}")
                    if car_n >= 900:
                        break
                    # Value unused here; the lookup is kept as in the original code.
                    car_idx = car_id_map[car_id]
                    car_exps = []
                    for edge_col in range(edge_index.shape[1]):
                        from_idx = edge_index[0, edge_col].item()
                        to_idx = edge_index[1, edge_col].item()
                        from_edge_id = edge_ids[from_idx]
                        to_edge_id = edge_ids[to_idx]
                        # State
                        s_prev = (from_edge_id, float(S_np[from_idx]))
                        # Action
                        a_prev = to_edge_id
                        # Reward, filled in after the episode
                        r_prev = None
                        # Next state
                        s_current = (to_edge_id, float(S_np[to_idx]))
                        exp = {
                            's_prev': s_prev,
                            'a_prev': a_prev,
                            'r_prev': r_prev,
                            's_current': s_current
                        }
                        car_exps.append(exp)
                    all_experiences[car_id] = car_exps
                    car_n += 1
                prev_vehicle_ids = curr_vehicle_ids.copy()
            print("car_n:", car_n)
            env.finish()
            # --- Draw random weights and use them to fill in the rewards ---
            weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
            xts = torch.cat((xts, weight.unsqueeze(0)), dim=0)
            for car_id, car_exps in all_experiences.items():
                car_idx = car_id_map[car_id]
                for exp in car_exps:
                    if exp['a_prev'] in edge_ids:
                        edge_idx = edge_ids.index(exp['a_prev'])
                        exp['r_prev'] = float(weight[edge_idx, car_idx].item())
                    else:
                        exp['r_prev'] = 0.0
            # --- Train one DQN per vehicle on its own experiences ---
            if not hasattr(CustomRouter, 'dqn_model'):
                CustomRouter.dqn_model = {}
            refs = []
            for car_id, car_exps in all_experiences.items():
                refs.append(train_dqn.remote(car_exps, car_id, edge_ids, CustomRouter.dqn_model))
            results = ray.get(refs)
            for trained in results:
                CustomRouter.dqn_model.update(trained)
            print("dqns:", CustomRouter.dqn_model)
            CustomRouter.dqn_edge_ids = edge_ids
            # --- Second simulation (with RL) to collect S ---
            env = TrafficEnvironment(0, simu_seconds)
            simulation_instance = Simulation()
            env.reset()
            done = False
            weight = None
            queue = torch.empty((0, num_road))
            s = torch.empty((0, num_road))
            prev_vehicle_ids = set()
            car_n = 0
            while not done:
                if weight is None:
                    weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
                next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight, car_id_map, is_RL=True)
                queue = torch.cat((queue, result.unsqueeze(0)), dim=0)
                # Store the queue lengths at the moment each vehicle first appears.
                curr_vehicle_ids = set(traci.vehicle.getIDList())
                new_vehicle_ids = curr_vehicle_ids - prev_vehicle_ids
                for car_id in new_vehicle_ids:
                    if car_n >= 900:
                        break
                    s = torch.cat((s, result.unsqueeze(0)), dim=0)
                    car_n += 1
                prev_vehicle_ids = curr_vehicle_ids.copy()
            env.finish()
            # Congestion of a sample = per-edge maximum over the episode.
            max_queue_values = queue.max(dim=0)[0]
            congestion = max_queue_values.float().to(device)  # [num_road]
            cons = torch.cat([cons, congestion.unsqueeze(0)], dim=0)
            # NOTE(review): this concatenation needs s to have exactly 900 rows,
            # i.e. at least 900 vehicles must have appeared — confirm.
            S = torch.cat([S, s.unsqueeze(0).to(device)], dim=0)
            with open('./data/cons.pickle', 'wb') as f:
                pickle.dump(cons, f)
            print(f"i={i}")
            print(f"S:{S.shape}")
        dataset_all = torch.cat([xts.float(), cons.unsqueeze(-1).float(), S.transpose(1, 2).float()], dim=2)
        with open('./data/dataset_all_queue.pickle', 'wb') as f:
            pickle.dump(dataset_all, f)
        x = xts.to(device)
        y = cons.to(device)
        s = S.to(device)
    print("generate_dataset", x.shape, y.shape)
    if not os.path.exists("./data/datasetD.pickle"):
        mse_values = torch.mean((y.unsqueeze(1) - new_condition) ** 2, dim=-1)
        # Keep the samples whose MSE against the target condition exceeds 30.
        valid_indices = torch.where(mse_values > 30)[0]
        filtered_x = x[valid_indices]
        filtered_y = y[valid_indices]
        filtered_s = s[valid_indices]
        print("1filtered_x", filtered_x.shape)
        print("1filtered_y", filtered_y.shape)
        datasetD = torch.cat([filtered_x.float(), filtered_y.float().unsqueeze(-1), filtered_s.transpose(1, 2).float()], dim=2)
        with open('./data/datasetD.pickle', 'wb') as f:
            pickle.dump(datasetD, f)
    else:
        # NOTE: same pickle caveat as above — trusted files only.
        with open('./data/datasetD.pickle', 'rb') as f:
            datasetD = pickle.load(f)
    torch.cuda.empty_cache()
    # NOTE(review): this exits the process once the datasets are on disk, so the
    # return below is never reached. Kept because it may be a deliberate
    # "generate only" mode — remove it if callers are meant to continue.
    sys.exit(0)
    return datasetD.to(device), dataset_all.to(device)

@ray.remote
def train_dqn(car_exps, car_id, edge_ids, dqn_model):
    """Train a DQN for a single vehicle as a Ray remote task.

    Args:
        car_exps: Experience records of the vehicle, in the format consumed by
            ``build_d3rlpy_dataset_from_experiences``.
        car_id: Vehicle identifier; used as the key under which the trained
            model is stored.
        edge_ids: Edge ID list used to translate edge IDs into indices.
        dqn_model: Mapping that receives the trained model under ``car_id``.

    Returns:
        ``dqn_model`` with the entry for ``car_id`` set to the trained DQN.
    """
    # The dataset builder expects a {car_id: experiences} mapping.
    offline_data = build_d3rlpy_dataset_from_experiences({car_id: car_exps}, edge_ids)

    # Train on CPU for a fixed number of gradient steps, tracking TD error.
    agent = DQNConfig().create(device="cpu")
    agent.build_with_dataset(offline_data)
    evaluators = {'td_error': TDErrorEvaluator(episodes=offline_data.episodes)}
    agent.fit(offline_data, n_steps=10000, evaluators=evaluators)

    dqn_model[car_id] = agent
    return dqn_model


def edge_indexf(net):
    """Build the edge-to-edge connectivity index of the road network.

    Args:
        net: Path of the SUMO network file, read with ``sumolib.net.readNet``
            to establish the edge ID -> position numbering.

    Returns:
        A ``torch.long`` tensor produced by transposing the list of
        ``[from_position, to_position]`` pairs, one pair per
        (edge, outgoing edge) connection found in ``Network.edges``.
    """
    sumo_net = sumolib.net.readNet(net)

    # Position of every edge ID in the order sumolib reports the edges.
    position_of = {}
    for position, sumo_edge in enumerate(sumo_net.getEdges()):
        position_of[sumo_edge.getID()] = position

    # One [source, target] pair per outgoing connection of each edge.
    connections = [
        [position_of[edge.getID()], position_of[successor.getID()]]
        for edge in Network.edges
        for successor in edge.getOutgoing()
    ]
    return torch.tensor(connections, dtype=torch.long).t()


def get_car_id_map(rou_path=r"./app/map/ChongQing-114.rou.xml"):
    """Map every vehicle ID declared in a SUMO route file to an integer index.

    Args:
        rou_path: Path of the ``.rou.xml`` file to read. Defaults to the
            ChongQing-114 route file, so existing zero-argument calls keep
            working.

    Returns:
        A dict ``{vehicle_id: index}``. Indices follow document order of the
        ``<vehicle>`` elements that carry a non-empty ``id`` attribute. If an
        ID occurs more than once, the index of its last occurrence is kept.
    """
    import xml.etree.ElementTree as ET

    root = ET.parse(rou_path).getroot()
    # iter('vehicle') walks the whole tree, so nested <vehicle> elements are
    # found too; elements without a usable id are ignored.
    vehicle_ids = [
        vehicle_id
        for vehicle_id in (element.get('id') for element in root.iter('vehicle'))
        if vehicle_id
    ]
    return {vehicle_id: index for index, vehicle_id in enumerate(vehicle_ids)}


def generated_weights_fun(model, autoencoder, Yautoencoder, latent_dim_road, latent_dim_car, new_condition, device):
    """Generate weight samples from the diffusion model for the target condition.

    A fixed number of (optionally noise-perturbed) copies of ``new_condition``
    is built, and ``generate_new_samples`` is called once per copy. With the
    current constants a single sample is produced and the noise scale is 0,
    so the condition is effectively passed through unperturbed.

    Returns:
        The generated weights stacked along a new leading sample dimension.
    """
    sample_count = 1
    jitter_scale = 0  # noise ratio applied to the condition

    # Perturbed conditions, stacked along a new leading dimension.
    conditions = torch.stack([
        new_condition + torch.randn_like(new_condition) * jitter_scale
        for _ in range(sample_count)
    ])

    # One diffusion sample per condition; the leading batch axis of size 1
    # returned by the generator is squeezed away before stacking.
    samples = [
        generate_new_samples(model, autoencoder, Yautoencoder, latent_dim_road, latent_dim_car,
                             conditions[k], 1, device).squeeze(0)
        for k in range(sample_count)
    ]
    return torch.stack(samples)


def load_my_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, reconciling the ``module.`` key prefix.

    Checkpoints saved from an ``nn.DataParallel`` wrapper carry a ``module.``
    prefix on every key. The keys are renamed so that the checkpoint matches
    the model regardless of which of the two was wrapped.

    Args:
        model: Target model, either plain or an ``nn.DataParallel`` instance.
        state_dict: Mapping of parameter names to tensors.

    Returns:
        ``model`` after ``load_state_dict`` has been applied.
    """
    wrapped_model = isinstance(model, nn.DataParallel)
    wrapped_keys = any(name.startswith('module.') for name in state_dict)

    if wrapped_model and not wrapped_keys:
        # Wrapped model, plain checkpoint: add the prefix.
        def rename(name):
            return f'module.{name}'
    elif wrapped_keys and not wrapped_model:
        # Plain model, wrapped checkpoint: drop only the first 'module.'.
        def rename(name):
            return name.replace('module.', '', 1)
    else:
        # Both sides already agree.
        def rename(name):
            return name

    model.load_state_dict({rename(name): tensor for name, tensor in state_dict.items()})
    return model


def dimensionality_reduction_for_diffusion_model(Yautoencoder, autoencoder, batch_size_DF, datasetD, device, num_car,
                                                 mapped_values):
    """Encode the high-score dataset and wrap it in a shuffled DataLoader.

    ``datasetD`` is split along its last axis into the weight part
    (``[..., :num_car]``), the congestion column (``[..., num_car]``) and the
    remaining state columns (``[..., num_car + 1:]``). The first two are
    compressed with ``autoencoder.encode`` / ``Yautoencoder.encode``.

    Args:
        Yautoencoder: Model exposing ``encode`` for the congestion vectors;
            may be wrapped in ``nn.DataParallel``.
        autoencoder: Model exposing ``encode`` for the weight matrices; may be
            wrapped in ``nn.DataParallel``.
        batch_size_DF: Batch size of the returned DataLoader.
        datasetD: 3-D tensor holding weights, congestion and states.
        device: Device on which encoding runs and where the encoded tensors
            are placed.
        num_car: Number of weight columns in ``datasetD``.
        mapped_values: Per-sample scores; its first dimension must match
            ``datasetD`` (enforced by ``TensorDataset``).

    Returns:
        A ``DataLoader`` yielding ``(encoded_x, encoded_y, states, scores)``.
    """
    dataset_x2 = datasetD[:, :, :num_car].float()
    dataset_y2 = datasetD[:, :, num_car].float()
    dataset_s = datasetD[:, :, num_car + 1:].float()

    # nn.DataParallel only forwards ``forward``; custom methods such as
    # ``encode`` live on the wrapped module, so unwrap before calling them.
    x_encoder = autoencoder.module if isinstance(autoencoder, nn.DataParallel) else autoencoder
    y_encoder = Yautoencoder.module if isinstance(Yautoencoder, nn.DataParallel) else Yautoencoder

    # Encode in small chunks to keep peak device memory low.
    batch_size_encode = 2

    dataset_x2_encode = []
    dataset_y2_encode = []

    autoencoder.eval()
    Yautoencoder.eval()
    with torch.no_grad():
        for start in range(0, dataset_x2.size(0), batch_size_encode):
            stop = start + batch_size_encode

            batch_x = dataset_x2[start:stop].to(device)
            # Park the result on the CPU so chunks do not accumulate on the device.
            dataset_x2_encode.append(x_encoder.encode(batch_x).cpu())

            batch_y = dataset_y2[start:stop].to(device)
            dataset_y2_encode.append(y_encoder.encode(batch_y).cpu())

    # Reassemble the encoded chunks on the target device.
    dataset_x2_encode = torch.cat(dataset_x2_encode, dim=0).to(device)
    dataset_y2_encode = torch.cat(dataset_y2_encode, dim=0).to(device)

    dataset2 = torch.utils.data.TensorDataset(dataset_x2_encode, dataset_y2_encode, dataset_s, mapped_values)
    dataloader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size_DF, shuffle=True)

    torch.cuda.empty_cache()
    return dataloader


def simulation_batch(car_id_map, device, edge_index, new_condition, simu_seconds, weights):
    """Simulate each weight matrix once and score it against the target condition.

    For every entry of ``weights`` a fresh environment is run to completion.
    The per-road maximum of the queue lengths reported by ``env.game_step``
    is taken as the congestion vector, and its mean squared error against
    ``new_condition`` is computed. Each MSE is appended to
    ``./d_loss_list_1401.csv``; each target/simulation pair is appended to
    ``./diff_1401.csv``.

    Args:
        car_id_map: Mapping passed through to ``env.game_step``.
        device: Device for the returned tensors.
        edge_index: Connectivity tensor passed through to ``env.game_step``.
        new_condition: Target congestion vector.
        simu_seconds: Simulation length passed to ``TrafficEnvironment``.
        weights: Iterable of weight matrices, one simulation per element.

    Returns:
        ``(congestions, mses)``: the stacked congestion vectors and a 1-D
        tensor of the matching MSE values on ``device``.
    """
    congestion_set = []
    mse_list = []

    for weight in weights:
        # A fresh environment and simulation object for every weight matrix.
        env = TrafficEnvironment(0, simu_seconds)
        simulation_instance = Simulation()
        env.reset()
        done = False

        # Queue length per road for every step. Rows are collected in a list
        # and stacked once, so no road count is needed up front (the previous
        # pre-allocation read a ``num_road`` name that is not defined in this
        # function) and the per-step re-allocation of torch.cat is avoided.
        queue_rows = []
        while not done:
            next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight,
                                                                 car_id_map)
            # clone() guards against the environment reusing its result buffer.
            queue_rows.append(result.clone())
        env.finish()

        queue = torch.stack(queue_rows, dim=0)
        congestion = queue.max(dim=0)[0].float().to(device)

        print("congestion", congestion)

        # Distance between the target condition and the simulated congestion.
        mse_value = torch.mean((new_condition - congestion) ** 2).item()

        congestion_set.append(congestion)
        mse_list.append(mse_value)

        # Append the MSE to the running loss log.
        with open("./d_loss_list_1401.csv", 'a', newline='', encoding='utf-8') as mycsvfile:
            csv.writer(mycsvfile).writerow([mse_value])

        # Append the target and the simulated congestion side by side.
        write_condition = new_condition.unsqueeze(0).cpu().detach().numpy()
        simu_congestion = congestion.unsqueeze(0).cpu().detach().numpy()
        with open("./diff_1401.csv", 'a', newline='', encoding='utf-8') as mycsvfile:
            writer = csv.writer(mycsvfile)
            writer.writerow(['target-{}s'.format((1) * simu_seconds)] + list(write_condition[0]))
            writer.writerow(['simulation-{}s'.format((1) * simu_seconds)] + list(simu_congestion[0]))

        torch.cuda.empty_cache()  # release cached GPU memory

    return torch.stack(congestion_set), torch.tensor(mse_list, device=device)


def simulation_C(num_samples, simu_seconds, edge_index, car_id_map, new_condition, model, autoencoder, Yautoencoder,
                 num_road, num_car, latent_dim_road, latent_dim_car, device):
    """Run ``num_samples`` rounds of: simulate, generate weights, train DQNs, re-simulate.

    Each round performs, in order:
      1. A first simulation (``is_RL=False``) with random weights that records
         one experience list per newly appearing vehicle (at most 900 vehicles)
         and one queue snapshot ``s`` per such vehicle.
      2. A call to ``generate_new_samples`` conditioned on ``s``; the returned
         weights are used as rewards for the recorded experiences.
      3. Training of one DQN per vehicle, stored in ``CustomRouter.dqn_model``.
      4. A second simulation (``is_RL=True``) whose per-road maximum queue
         length is the congestion vector of the round.
    The MSE between ``new_condition`` and the congestion is appended to
    ``./d_loss_list_1401.csv`` and the pair is appended to ``./diff_1401.csv``.

    Returns:
        A tuple ``(congestions, mses, generated_weights, S)`` where each of
        ``congestions``, ``generated_weights`` and ``S`` is stacked over the
        rounds and ``mses`` is a 1-D tensor on ``device``.
    """
    congestion_set = []
    mse_list = []
    generated_weights = []
    S = []
    for i in range(num_samples):
        # Create the simulation environment.
        env = TrafficEnvironment(0, simu_seconds)
        # A fresh Simulation object is created on every round.
        simulation_instance = Simulation()
        # Start SUMO and obtain the initial environment state.
        env.reset()
        done = False
        weight = None
        # Queue length of every road, one row per simulation step.
        queue = torch.empty((0, num_road))
        # Vehicle IDs seen on the previous step, used to detect new vehicles.
        prev_vehicle_ids = set()
        all_experiences = {}  # {car_id: [experience, ...], ...}
        edge_ids = simulation_instance.edge_ids  # list of edge IDs
        s = torch.empty((0, num_road))
        S_np = None
        car_n = 0
        while not done:
            # Random weights, drawn once on the first step.
            if weight is None:
                weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
            next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight, car_id_map, is_RL=False)
            queue = torch.cat((queue, result.unsqueeze(0)), dim=0)
            # Vehicles currently present in the simulation.
            curr_vehicle_ids = set(traci.vehicle.getIDList())
            # Vehicles that appeared since the previous step.
            new_vehicle_ids = curr_vehicle_ids - prev_vehicle_ids
            # NOTE(review): S_np is captured on the first step only, so every
            # experience below uses the first-step queue values — confirm this
            # is intended rather than the current step's values.
            if S_np is None:
                S_np = result.cpu().detach().numpy()
            for car_id in new_vehicle_ids:
                # Only the first 900 vehicles are recorded.
                if car_n >= 900:
                    break
                car_idx = car_id_map[car_id]
                car_exps = []
                # One experience per connection (column) of edge_index.
                for edge_col in range(edge_index.shape[1]):
                    from_idx = edge_index[0, edge_col].item()
                    to_idx = edge_index[1, edge_col].item()
                    from_edge_id = edge_ids[from_idx]
                    to_edge_id = edge_ids[to_idx]
                    # State: (edge ID, queue value on that edge).
                    s_prev = (from_edge_id, float(S_np[from_idx]))
                    # Action: the edge moved to.
                    a_prev = to_edge_id
                    # Reward: filled in after the weights are generated.
                    r_prev = None
                    # Next state.
                    s_current = (to_edge_id, float(S_np[to_idx]))
                    exp = {
                        's_prev': s_prev,
                        'a_prev': a_prev,
                        'r_prev': r_prev,
                        's_current': s_current
                    }
                    car_exps.append(exp)
                all_experiences[car_id] = car_exps
                # One queue snapshot per recorded vehicle.
                s = torch.cat((s, result.unsqueeze(0)), dim=0)
                car_n += 1
            prev_vehicle_ids = curr_vehicle_ids.copy()
        env.finish()
        # --- Generate weights with the diffusion model and fill in the rewards ---
        weight = generate_new_samples(model, autoencoder, Yautoencoder, latent_dim_road, latent_dim_car,
                                      s, 1, device)
        generated_weights.append(weight)
        # NOTE(review): the indexing below presumes ``weight`` is 2-D
        # (road, car); generated_weights_fun squeezes a leading axis off the
        # same generator's output — confirm the shape returned here.
        for car_id, car_exps in all_experiences.items():
            car_idx = car_id_map[car_id]
            for exp in car_exps:
                if exp['a_prev'] in edge_ids:
                    edge_idx = edge_ids.index(exp['a_prev'])
                    exp['r_prev'] = float(weight[edge_idx, car_idx].item())
                else:
                    exp['r_prev'] = 0.0
        # --- Build the datasets and train the DQNs ---
        # Every vehicle gets its own dataset and its own DQN model.
        if not hasattr(CustomRouter, 'dqn_model'):
            CustomRouter.dqn_model = {}
        for car_id, car_exps in all_experiences.items():
            # Experience dataset of the current vehicle.
            car_specific_experiences = {car_id: car_exps}
            d3rlpy_dataset = build_d3rlpy_dataset_from_experiences(car_specific_experiences, edge_ids)

            # Train the DQN of the current vehicle.
            dqn = DQNConfig().create(device="cuda:0")
            dqn.build_with_dataset(d3rlpy_dataset)
            td_error_evaluator = TDErrorEvaluator(episodes=d3rlpy_dataset.episodes)
            dqn.fit(
                d3rlpy_dataset,
                n_steps=10000,
                evaluators={
                    'td_error': td_error_evaluator,
                },
            )
            # Store the model under the vehicle's ID on the router class.
            CustomRouter.dqn_model[car_id] = dqn
        CustomRouter.dqn_edge_ids = edge_ids
        # --- Second simulation (RL routing enabled), collecting S ---
        env = TrafficEnvironment(0, simu_seconds)
        simulation_instance = Simulation()
        env.reset()
        done = False
        # The weight passed to game_step in this run is a fresh random matrix,
        # not the diffusion-generated one above.
        weight = None
        queue = torch.empty((0, num_road))
        prev_vehicle_ids = set()
        s = torch.empty((0, num_road))
        car_n = 0
        while not done:
            if weight is None:
                weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
            next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight, car_id_map, is_RL=True)
            queue = torch.cat((queue, result.unsqueeze(0)), dim=0)
            # Vehicles currently present in the simulation.
            curr_vehicle_ids = set(traci.vehicle.getIDList())
            # Vehicles that appeared since the previous step.
            new_vehicle_ids = curr_vehicle_ids - prev_vehicle_ids
            for car_id in new_vehicle_ids:
                # At most 900 snapshots; fewer are kept if fewer vehicles appear.
                if car_n >= 900:
                    break
                s = torch.cat((s, result.unsqueeze(0)), dim=0)
                car_n += 1
            prev_vehicle_ids = curr_vehicle_ids.copy()
        env.finish()
        S.append(s)
        # Congestion = per-road maximum queue length over the whole run.
        max_queue_values = queue.max(dim=0)[0]
        congestion = max_queue_values.float().to(device)

        print("congestion", congestion)

        # MSE between the target condition and the simulated congestion.
        mse = torch.mean((new_condition - congestion) ** 2)

        congestion_set.append(congestion)
        mse_list.append(mse.item())

        # Append the MSE to the running loss log.
        with open("./d_loss_list_1401.csv", 'a', newline='', encoding='utf-8') as mycsvfile:
            writer = csv.writer(mycsvfile)
            writer.writerow([mse.item()])

        # Append the target condition and the simulated congestion side by side.
        write_condition = new_condition.unsqueeze(0).cpu().detach().numpy()
        simu_congestion = congestion.unsqueeze(0).cpu().detach().numpy()
        with open("./diff_1401.csv", 'a', newline='', encoding='utf-8') as mycsvfile:
            writer = csv.writer(mycsvfile)
            writer.writerow(['target-{}s'.format((1) * simu_seconds)] + list(write_condition[0]))
            writer.writerow(['simulation-{}s'.format((1) * simu_seconds)] + list(simu_congestion[0]))

        torch.cuda.empty_cache()  # release cached GPU memory
    # Stack the per-round results along a new leading dimension.
    # NOTE(review): stacking S requires every round to have collected the same
    # number of snapshots — verify when num_samples > 1.
    generated_weights = torch.stack(generated_weights)
    S = torch.stack(S)
    return torch.stack(congestion_set), torch.tensor(mse_list, device=device), generated_weights, S

def create_Q(simu_seconds, edge_index, device, num_road, num_car, car_id_map):
    """Produce the target congestion vector and pickle it to ./data/Q_queue.pickle.

    The function runs three stages:
      1. A simulation with random weights (``is_RL=False``) that records one
         experience list per newly appearing vehicle (at most 900 vehicles).
      2. Reward assignment from a second random weight matrix, followed by
         training and saving one DQN per vehicle
         (``./data/dqns/dqn-<car_id>.d3``); the models are also stored in
         ``CustomRouter.dqn_model``.
      3. A second simulation with ``is_RL=True``; the per-road maximum queue
         length of that run is the congestion vector that gets pickled.

    Args:
        simu_seconds: Simulation length passed to ``TrafficEnvironment``.
        edge_index: Connectivity tensor; row 0 holds source positions and
            row 1 target positions into ``simulation_instance.edge_ids``.
        device: Device for the weight matrices and the resulting vector.
        num_road: Number of roads (rows of the weight matrix).
        num_car: Number of vehicles (columns of the weight matrix).
        car_id_map: Mapping from vehicle ID to column index.

    Returns:
        None. The result is written to disk.
    """
    # ---- Stage 1: random-weight simulation, collecting experiences ----
    env = TrafficEnvironment(0, simu_seconds)
    simulation_instance = Simulation()
    env.reset()
    done = False
    weight = None
    prev_vehicle_ids = set()  # vehicles seen on the previous step
    all_experiences = {}  # {car_id: [experience, ...], ...}
    edge_ids = simulation_instance.edge_ids  # list of edge IDs
    S_np = None
    car_n = 0
    while not done:
        # Random weights, drawn once on the first step.
        if weight is None:
            weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
        next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight, car_id_map,
                                                             is_RL=False)
        curr_vehicle_ids = set(traci.vehicle.getIDList())
        new_vehicle_ids = curr_vehicle_ids - prev_vehicle_ids
        # NOTE(review): the queue snapshot is taken on the first step only and
        # reused for every experience — confirm this is intended.
        if S_np is None:
            S_np = result.cpu().detach().numpy()
        for car_id in new_vehicle_ids:
            # Only the first 900 vehicles are recorded.
            if car_n >= 900:
                break
            # Fail fast on vehicles missing from the map; the reward stage
            # below needs their column index.
            if car_id not in car_id_map:
                raise KeyError(car_id)
            car_exps = []
            # One experience per connection (column) of edge_index.
            for edge_col in range(edge_index.shape[1]):
                from_idx = edge_index[0, edge_col].item()
                to_idx = edge_index[1, edge_col].item()
                from_edge_id = edge_ids[from_idx]
                to_edge_id = edge_ids[to_idx]
                car_exps.append({
                    # State: (edge ID, queue value on that edge).
                    's_prev': (from_edge_id, float(S_np[from_idx])),
                    # Action: the edge moved to.
                    'a_prev': to_edge_id,
                    # Reward: filled in during stage 2.
                    'r_prev': None,
                    # Next state.
                    's_current': (to_edge_id, float(S_np[to_idx])),
                })
            all_experiences[car_id] = car_exps
            car_n += 1
        prev_vehicle_ids = curr_vehicle_ids.copy()
    print("car_n:", car_n)
    env.finish()

    # ---- Stage 2: random rewards, one DQN per vehicle ----
    weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
    for car_id, car_exps in all_experiences.items():
        car_idx = car_id_map[car_id]
        for exp in car_exps:
            if exp['a_prev'] in edge_ids:
                edge_idx = edge_ids.index(exp['a_prev'])
                exp['r_prev'] = float(weight[edge_idx, car_idx].item())
            else:
                exp['r_prev'] = 0.0

    if not hasattr(CustomRouter, 'dqn_model'):
        CustomRouter.dqn_model = {}
    # dqn.save() below writes into this directory; create it up front so the
    # first save does not fail on a fresh checkout.
    os.makedirs("./data/dqns", exist_ok=True)
    for car_id, car_exps in all_experiences.items():
        car_specific_experiences = {car_id: car_exps}
        d3rlpy_dataset = build_d3rlpy_dataset_from_experiences(car_specific_experiences, edge_ids)

        dqn = DQNConfig().create(device="cuda:0")
        dqn.build_with_dataset(d3rlpy_dataset)
        td_error_evaluator = TDErrorEvaluator(episodes=d3rlpy_dataset.episodes)
        dqn.fit(
            d3rlpy_dataset,
            n_steps=10000,
            evaluators={
                'td_error': td_error_evaluator,
            },
        )
        dqn.save(f"./data/dqns/dqn-{car_id}.d3")
        # Make the model available to the router under the vehicle's ID.
        CustomRouter.dqn_model[car_id] = dqn
    CustomRouter.dqn_edge_ids = edge_ids

    # ---- Stage 3: RL-routed simulation, measuring congestion ----
    env = TrafficEnvironment(0, simu_seconds)
    simulation_instance = Simulation()
    env.reset()
    done = False
    weight = None
    queue = torch.empty((0, num_road))  # one row of queue lengths per step
    while not done:
        if weight is None:
            # A stray argument-less torch.zeros() call used to follow this
            # line and raised TypeError on the first step; it has been removed.
            weight = torch.rand(num_road, num_car).round(decimals=4).to(device)
        next_graph_state, done, result, real = env.game_step(simulation_instance, edge_index, weight, car_id_map,
                                                             is_RL=True)
        queue = torch.cat((queue, result.unsqueeze(0)), dim=0)
    env.finish()

    # Congestion = per-road maximum queue length over the whole run.
    max_queue_values = queue.max(dim=0)[0]
    congestion = max_queue_values.float().to(device)
    print("congestion", congestion)
    with open('./data/Q_queue.pickle', 'wb') as f:
        pickle.dump(congestion, f)


def build_d3rlpy_dataset_from_experiences(all_experiences, edge_ids=None):
    """Convert the custom experience store into a ``d3rlpy`` ``MDPDataset``.

    Args:
        all_experiences: Mapping ``{car_id: [experience, ...]}``. Every
            experience is a dict with the keys ``s_prev`` (``(edge_id,
            value)``), ``a_prev`` (edge ID), ``r_prev`` (number) and
            ``s_current`` (``(edge_id, value)``).
        edge_ids: Sequence of known edge IDs; an ID's position in it is the
            index stored in the dataset. ``None`` is treated as empty.

    Returns:
        An ``MDPDataset`` whose observations are ``[edge_index, value]``
        pairs and whose actions are single edge indices. The last retained
        transition of every trajectory is flagged as terminal.

    Raises:
        ValueError: If no experience could be converted.

    Experiences that reference an edge ID missing from ``edge_ids`` are
    skipped. Other malformed records (for example a non-numeric reward) raise
    instead of being silently dropped.
    """
    import numpy as np
    import d3rlpy

    known_edges = [] if edge_ids is None else edge_ids

    obs_list = []
    act_list = []
    rew_list = []
    terminals_list = []

    for traj in all_experiences.values():
        kept_before = len(obs_list)
        for exp in traj:
            s_prev = exp['s_prev']
            a_prev = exp['a_prev']
            r_prev = exp['r_prev']
            s_current = exp['s_current']
            try:
                s_prev_idx = known_edges.index(s_prev[0])
                a_prev_idx = known_edges.index(a_prev)
                # Looked up only to validate the next state's edge; the
                # dataset itself does not store next observations.
                known_edges.index(s_current[0])
            except ValueError:
                # Unknown edge ID: skip this transition.
                continue
            obs_list.append(np.array([s_prev_idx, s_prev[1]], dtype=np.float32))
            act_list.append(np.array([a_prev_idx], dtype=np.float32))
            rew_list.append(float(r_prev))
            terminals_list.append(0.0)
        # Close the episode on the last transition that was actually kept, so
        # the terminal flag survives even if trailing records were skipped.
        if len(obs_list) > kept_before:
            terminals_list[-1] = 1.0

    if not obs_list:
        raise ValueError('经验库为空，无法构建d3rlpy数据集')

    dataset = d3rlpy.dataset.MDPDataset(
        observations=np.stack(obs_list),
        actions=np.stack(act_list),
        rewards=np.array(rew_list, dtype=np.float32),
        terminals=np.array(terminals_list, dtype=np.float32),
    )
    return dataset


# 修改主函数
def main():
    """Run the full optimisation pipeline.

    Steps:
      1. Load the target congestion vector from ``./data/Q_queue.pickle``.
      2. Build the two autoencoders and the diffusion model (wrapped in
         ``BalancedDataParallel`` when more than one GPU is visible) and
         restore their checkpoints when present.
      3. Obtain the initial datasets from ``generate_dataset`` and pre-train
         the models if any checkpoint is missing.
      4. For 3000 rounds: simulate with ``simulation_C``, append the new
         samples to the datasets (low-MSE samples also go to ``datasetD``),
         persist them, and fine-tune all three models.

    Raises:
        FileNotFoundError: If ``./data/Q_queue.pickle`` does not exist.
    """

    def _scores_from_mse(mse):
        """Min-max normalise ``mse`` to [0, 1], invert it and scale to [0, 40]."""
        span = mse.max() - mse.min()
        # With identical MSEs (e.g. a single sample) the range is 0 and the
        # plain division would yield NaN; clamping maps them all to 40.
        return 40 * (1 - (mse - mse.min()) / torch.clamp(span, min=1e-12))

    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = "./app/map/ChongQing-114.net.xml"
    car_id_map = get_car_id_map()
    # Connectivity (edge_index) of the road topology.
    edge_index = edge_indexf(net)
    edge_index = edge_index.to(device)
    # Hyper-parameters.
    num_road = 114  # original input dimension
    num_car = 1557
    latent_dim_road = 8  # road dimension after reduction
    latent_dim_car = 30  # car dimension after reduction
    condition_dim = 1
    hidden_dim = 128
    learning_rate = 0.0001
    num_samples = 1  # number of pre-training samples
    simu_seconds = 480  # simulated seconds
    batch_size_DF = 5  # diffusion-model batch size
    batch_size_autoencoder = 5
    batch_size_Yautoencoder = 5

    # Target condition. create_Q() is the function that writes this file.
    if os.path.exists("./data/Q_queue.pickle"):
        with open('./data/Q_queue.pickle', 'rb') as f:
            new_condition = pickle.load(f)
    else:
        print("没有目标文件")
        # Stop with a clear error instead of failing later on an unbound name.
        raise FileNotFoundError(
            "./data/Q_queue.pickle not found; run create_Q() first to generate the target condition")
    new_condition = new_condition.clone().detach().float().to(device)

    # Random seeds.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        autoencoder = Autoencoder(edge_index, num_road, num_car, latent_dim_road, latent_dim_car)
        Yautoencoder = YAutoencoder(num_road, latent_dim_road)
        model = DiffusionModel(latent_dim_car, condition_dim, hidden_dim)
        autoencoder = BalancedDataParallel(2, autoencoder, dim=0)
        Yautoencoder = BalancedDataParallel(2, Yautoencoder, dim=0)
        model = BalancedDataParallel(2, model, dim=0)

        # Restore checkpoints; load_my_state_dict reconciles 'module.' prefixes.
        if os.path.exists("graph_autoencoder1401.pth"):
            state_dict = torch.load("graph_autoencoder1401.pth")
            autoencoder = load_my_state_dict(autoencoder, state_dict)
        if os.path.exists("Yautoencoder_model1401.pth"):
            state_dict = torch.load("Yautoencoder_model1401.pth")
            Yautoencoder = load_my_state_dict(Yautoencoder, state_dict)
        if os.path.exists("DiffusionModel1401.pth"):
            state_dict = torch.load("DiffusionModel1401.pth")
            model = load_my_state_dict(model, state_dict)

        autoencoder.to(device)
        Yautoencoder.to(device)
        model.to(device)
    else:
        # Single-device setup: build each model and restore its checkpoint.
        autoencoder = Autoencoder(edge_index, num_road, num_car, latent_dim_road, latent_dim_car).to(device)
        if os.path.exists("graph_autoencoder1401.pth"):
            autoencoder.load_state_dict(torch.load("graph_autoencoder1401.pth"))
        Yautoencoder = YAutoencoder(num_road, latent_dim_road).to(device)
        if os.path.exists("Yautoencoder_model1401.pth"):
            Yautoencoder.load_state_dict(torch.load("Yautoencoder_model1401.pth"))
        model = DiffusionModel(latent_dim_car, condition_dim, hidden_dim).to(device)
        if os.path.exists("DiffusionModel1401.pth"):
            model.load_state_dict(torch.load("DiffusionModel1401.pth"))

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Initial datasets.
    # NOTE(review): generate_dataset currently calls sys.exit(0) right before
    # its return statement, so execution does not get past this call — confirm
    # whether that exit is a leftover.
    datasetD, dataset_all = generate_dataset(edge_index, car_id_map, new_condition, device, num_road, simu_seconds,
                                             num_car, num_samples)
    y_D = datasetD[:, :, num_car].float()
    mse_values = torch.mean((y_D.unsqueeze(1) - new_condition) ** 2, dim=-1)
    with open('./data/mse_values1.pickle', 'wb') as f:
        pickle.dump(mse_values.float(), f)
    # Lower MSE -> higher score, on a 0-40 scale.
    mapped_values = _scores_from_mse(mse_values)

    if ((not os.path.exists("graph_autoencoder1401.pth")) or (not os.path.exists("Yautoencoder_model1401.pth")) or
            (not os.path.exists("DiffusionModel1401.pth"))):
        # Pre-train the encoders on all data.
        print("start:", datetime.now())
        autoencoder = train_autoencoder(autoencoder, dataset_all, num_car, device, batch_size=batch_size_autoencoder,
                                        num_epochs=50)
        print("autoencoder:", datetime.now())
        Yautoencoder = train_Yautoencoder(Yautoencoder, dataset_all, num_car, device,
                                          batch_size=batch_size_Yautoencoder,
                                          num_epochs=100)
        print("Yautoencoder:", datetime.now())
        # Pre-train the diffusion model on the high-score dataset.
        dataloader = dimensionality_reduction_for_diffusion_model(Yautoencoder, autoencoder, batch_size_DF, datasetD,
                                                                  device, num_car, mapped_values)
        model = train_diffusion_model(model, dataloader, optimizer, device, num_epochs=100)
        print("diffusion_model:", datetime.now())
        torch.cuda.empty_cache()  # release cached GPU memory

    for i in tqdm(range(3000), desc="Optimization Progress"):
        # One round of simulation with generated weights and RL routing.
        all_congestion_set, all_mse, generated_weights, all_S = simulation_C(1, simu_seconds, edge_index, car_id_map, new_condition, model, autoencoder, Yautoencoder, num_road, num_car, latent_dim_road, latent_dim_car, device)
        new_lines = torch.cat([generated_weights, all_congestion_set.unsqueeze(-1), all_S.transpose(1, 2).float()], dim=2)
        # Samples with MSE below 20 also join the high-score dataset.
        valid_indices = torch.where(all_mse < 20)[0]
        if len(valid_indices) > 0:
            new_lines_D = new_lines[valid_indices]
            datasetD = torch.cat((datasetD, new_lines_D), dim=0)
            if os.path.exists("./data/mse_values1.pickle"):
                with open('./data/mse_values1.pickle', 'rb') as f:
                    mse_values = pickle.load(f)
            new_lines_mse = all_mse[valid_indices]
            mse_values = torch.cat((mse_values, new_lines_mse.unsqueeze(-1)), dim=0)
            # Re-score the whole high-score dataset after the extension.
            mapped_values = _scores_from_mse(mse_values)
            # Persist the extended dataset and its MSE values.
            with open('./data/datasetD.pickle', 'wb') as f:
                pickle.dump(datasetD.float(), f)
            with open('./data/mse_values1.pickle', 'wb') as f:
                pickle.dump(mse_values.float(), f)

        # Every new sample joins the full dataset.
        dataset_all = torch.cat((dataset_all, new_lines), dim=0)
        with open('./data/dataset_all_queue.pickle', 'wb') as f:
            pickle.dump(dataset_all.float(), f)

        dataset_x2 = datasetD[:, :, :num_car].float()
        dataset_y2 = datasetD[:, :, num_car].float()
        dataset_s = datasetD[:, :, num_car+1:].float()
        print("dataset_x2", dataset_x2.shape)
        print("dataset_y2", dataset_y2.shape)
        print("dataset_s", dataset_s.shape)
        print("mse", all_mse)
        print("start:", datetime.now())
        # Fine-tune the encoders on all data.
        autoencoder = train_autoencoder(autoencoder, dataset_all.detach(), num_car, device,
                                        batch_size=batch_size_autoencoder, num_epochs=5)
        print("autoencoder:", datetime.now())
        Yautoencoder = train_Yautoencoder(Yautoencoder, dataset_all.detach(), num_car, device,
                                          batch_size=batch_size_Yautoencoder, num_epochs=10)
        print("Yautoencoder:", datetime.now())
        dataloader = dimensionality_reduction_for_diffusion_model(Yautoencoder, autoencoder, batch_size_DF, datasetD,
                                                                  device, num_car, mapped_values)
        # Fine-tune the diffusion model on the high-score dataset.
        model = train_diffusion_model(model, dataloader, optimizer, device, num_epochs=10)
        print("diffusion_model:", datetime.now())
        # Free the cache only under memory pressure; the device query is
        # skipped on hosts without CUDA, where it would raise.
        if (torch.cuda.is_available()
                and torch.cuda.memory_reserved() > 0.8 * torch.cuda.get_device_properties(0).total_memory):
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # Start every run from a clean slate: drop the model checkpoints and CSV
    # logs left over from a previous run, then launch the pipeline.
    for stale_artifact in (
        'DiffusionModel1401.pth',
        'graph_autoencoder1401.pth',
        'Yautoencoder_model1401.pth',
        'd_loss_list_1401.csv',
        'diff_1401.csv',
        'd_W_1_1401.csv',
    ):
        if os.path.exists(stale_artifact):
            os.remove(stale_artifact)
    main()
