import argparse
import time
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn.conv import gcn_conv
from torch.utils.data import DataLoader, Dataset
from gnnexp import models
import numpy as np
import os  # 添加os模块导入

# Select the compute device once, at import time: CUDA when it is available,
# otherwise CPU. The model code below reads this module-level global.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Command-line options for training the explainer.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='mnist_balanced_base_h64_o32_6layer', help='Name of dataset.')
parser.add_argument('--output', type=str, default=None, help='output path.')
parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
parser.add_argument('-e', '--epoch', type=int, default=300, help='Number of training epochs.')
parser.add_argument('-b', '--batch_size', type=int, default=128, help='Number of samples in a minibatch.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--max_grad_norm', type=float, default=1, help='max_grad_norm.')
parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).')
parser.add_argument('--encoder_hidden1', type=int, default=32, help='Number of units in hidden layer 1.')
parser.add_argument('--encoder_hidden2', type=int, default=16, help='Number of units in hidden layer 2.')
parser.add_argument('--encoder_output', type=int, default=16, help='Dim of output of VGAE encoder.')
parser.add_argument('--decoder_hidden1', type=int, default=16, help='Number of units in decoder hidden layer 1.')
parser.add_argument('--decoder_hidden2', type=int, default=16, help='Number of units in decoder hidden layer 2.')
parser.add_argument('--K', type=int, default=8, help='Number of casual factors.')
parser.add_argument('--coef_lambda', type=float, default=0.01, help='Coefficient of gae loss.')
parser.add_argument('--coef_kl', type=float, default=0.01, help='Coefficient of KL loss.')
parser.add_argument('--coef_causal', type=float, default=1.0, help='Coefficient of causal loss.')
parser.add_argument('--coef_size', type=float, default=0.0, help='Coefficient of size loss.')
# NOTE(review): --NX / --NA share the help text of --Nalpha; confirm their meaning.
parser.add_argument('--NX', type=int, default=1, help='Number of monte-carlo samples per causal factor.')
parser.add_argument('--NA', type=int, default=1, help='Number of monte-carlo samples per causal factor.')
parser.add_argument('--Nalpha', type=int, default=25, help='Number of monte-carlo samples per causal factor.')
parser.add_argument('--Nbeta', type=int, default=100, help='Number of monte-carlo samples per noncausal factor.')
parser.add_argument('--node_perm', action="store_true",
                    help='Use node permutation as data augmentation for causal training.')
parser.add_argument('--load_ckpt', default=None, help='Load parameters from checkpoint.')
# NOTE(review): no type/action is given, so any value passed on the command
# line (even "False") arrives as a non-empty, truthy string; only the default
# is a real bool. Confirm how callers interpret this flag.
parser.add_argument('--gpu', default=False)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--retrain', action='store_true')
parser.add_argument('--patient', type=int, default=100, help='Patient for early stopping.')
parser.add_argument('--plot_info_flow', action='store_true')

# Dataset configuration table: one row per dataset with its checkpoint file,
# feature_dim and num_classes; every entry uses max_nodes = 100. Insertion
# order is preserved, which is the order get_dataset_config lists the
# supported names in its error message.
DATASET_CONFIGS = {
    name: {
        'ckpt_file': ckpt_file,
        'feature_dim': feature_dim,
        'num_classes': num_classes,
        'max_nodes': 100,
    }
    for name, ckpt_file, feature_dim, num_classes in (
        ('BA-2motif', 'BA-2motif_base_h20_o20.pth.tar', 10, 2),
        ('BBBP', 'BBBP_base_h20_o20.pth.tar', 10, 2),
        ('Mutagenicity', 'Mutagenicity_base_h20_o20.pth.tar', 14, 2),
        ('NCI1', 'NCI1_base_h20_o20.pth.tar', 37, 2),
        ('ba3_base_h64_o32', 'ba3_base_h64_o32.pth.tar', 10, 3),
        ('mnist_balanced_base_h64_o32_6layer', 'mnist_balanced_base_h64_o32_6layer.pth.tar', 5, 10),
    )
}

def get_dataset_config(dataset_name):
    """Return the configuration dict registered for ``dataset_name``.

    The returned dict is the entry stored in ``DATASET_CONFIGS`` itself (not a
    copy).

    Raises:
        ValueError: if ``dataset_name`` is not a key of ``DATASET_CONFIGS``;
            the message lists every supported name.
    """
    try:
        return DATASET_CONFIGS[dataset_name]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {dataset_name}. Supported datasets: {list(DATASET_CONFIGS.keys())}") from None

def edge_importance_bce_loss(edge_probs, edge_indices, preds, pred):
    """Pull the scores of each selected edge set toward +1 or -1.

    Despite the name, the per-sample objective is an L1 loss, not binary
    cross-entropy. For each ``(pred_i, edge_idx)`` pair, the scores
    ``edge_probs[edge_idx[0]]`` are pulled toward +1 when ``pred_i``'s arg-max
    class equals ``pred``'s arg-max class, and toward -1 otherwise.

    Args:
        edge_probs: 1-D tensor of per-edge scores.
        edge_indices: sequence aligned with ``preds``. Element ``[0]`` of each
            entry is used to index ``edge_probs`` directly, so it must hold
            edge positions (not node ids).
        preds: sequence of predictions; each is squeezed on dim 0 before its
            arg-max is taken.
        pred: reference prediction whose arg-max is the target class.

    Returns:
        0-d tensor: the sum of the per-sample losses divided by ``len(preds)``.
        Each per-sample loss is the mean absolute error further divided by the
        number of selected edges, so larger selections weigh less. A sample
        that selects no edges contributes 0 (rather than NaN), and an empty
        ``preds`` yields 0 (rather than raising ZeroDivisionError).
    """
    batch_size = len(preds)
    if batch_size == 0:
        return edge_probs.new_zeros(())

    target_class = pred.argmax()
    total_loss = edge_probs.new_zeros(())
    for pred_i, edge_idx in zip(preds, edge_indices):
        edge_scores = edge_probs[edge_idx[0]]
        num_selected = edge_scores.shape[0]
        if num_selected == 0:
            # F.l1_loss over an empty tensor is NaN and would poison the sum.
            continue

        agrees = (pred_i.squeeze(0).argmax() == target_class).item()
        target = torch.full_like(edge_scores, 1.0 if agrees else -1.0)
        # NOTE(review): l1_loss already averages over the selection; the extra
        # division by its size is kept as-is — confirm it is intended.
        total_loss = total_loss + F.l1_loss(edge_scores, target) / (num_selected + 1e-8)

    return total_loss / batch_size


class GraphDataset(Dataset):
    """Map-style dataset over a ``cg_dict`` restricted to the sample ``ids``.

    ``cg_dict`` must provide ``"feat"``, ``"adj"``, ``"label"`` and ``"pred"``.
    ``feat``/``adj`` may be either indexable per sample or a Python list that
    is shorter than the number of samples (indexed modulo its length).
    ``pred`` is indexed as ``pred[0, i]``, i.e. it must support 2-D indexing
    with a leading axis of which only slot 0 is used.
    """

    def __init__(self, cg_dict, ids):
        self.cg_dict = cg_dict
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def to_tensor(self, x):
        """Return ``x`` as a float32 tensor.

        Accepts tensors, numpy arrays, numpy scalars (e.g. one element of a
        1-D label array) and Python numbers/lists. ``torch.from_numpy`` only
        handled ndarrays and raised TypeError for scalars.
        """
        return torch.as_tensor(x).float()

    def _lookup(self, key, i):
        """Fetch sample ``i`` of ``cg_dict[key]``, wrapping around for lists."""
        values = self.cg_dict[key]
        if isinstance(values, list):
            # e.g. BA3: the list is shorter than the number of samples, so the
            # index is taken modulo its length.
            return values[i % len(values)]
        return values[i]

    def __getitem__(self, idx):
        """Return ``(feat, adj, label, pred)`` float32 tensors for ``ids[idx]``."""
        i = self.ids[idx]
        feat = self.to_tensor(self._lookup("feat", i))
        adj = self.to_tensor(self._lookup("adj", i))
        label = self.to_tensor(self.cg_dict["label"][i])
        pred = self.to_tensor(self.cg_dict["pred"][0, i])
        return feat, adj, label, pred


class GCNModel(torch.nn.Module):
    """Three-layer GCN producing per-node embeddings.

    conv1: num_node_features -> hidden_channels
    conv2: hidden_channels   -> hidden_channels
    conv3: hidden_channels   -> embedding_dim
    The first two layers are followed by ReLU and dropout (p=0.3).
    """

    def __init__(self, num_node_features, hidden_channels, embedding_dim=20):
        super().__init__()
        # NOTE: reseeds the *global* torch RNG so the conv weights below are
        # initialised reproducibly; later random draws are affected as well.
        torch.manual_seed(12345)
        self.conv1 = gcn_conv.GCNConv(num_node_features, hidden_channels)
        self.conv2 = gcn_conv.GCNConv(hidden_channels, hidden_channels)
        self.conv3 = gcn_conv.GCNConv(hidden_channels, embedding_dim)  # embedding output
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index):
        """Return node embeddings of width ``embedding_dim``."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = self.dropout(F.relu(conv(hidden, edge_index)))
        return self.conv3(hidden, edge_index)

    def get_node_embeddings(self, x, edge_index):
        """Node embeddings used for edge-importance scoring (same as ``forward``)."""
        return self.forward(x, edge_index)


class BaseModel(torch.nn.Module):
    """Multi-scale edge-selection explainer built around a graph classifier.

    For every graph in a batch it scores each edge, keeps the top
    ``k1 .. k9`` fractions of edges as nine nested "key" subgraphs (plus
    their complements), runs ``classifier`` on the full graph and on every
    subgraph, and turns those outputs into two training losses.

    Args:
        gcn_model: module exposing ``get_node_embeddings(x, edge_index)``;
            used only to produce node embeddings for edge scoring.
        mlp_model_1: maps the element-wise product of two node embeddings to
            one edge score (its last dimension is squeezed away).
        classifier: called as ``classifier(x[None], adj[None])`` and must
            return a pair whose first element is the prediction.
            ``get_classifier_features`` also uses its ``conv_first``,
            ``conv_block``, ``act``, ``bn`` and ``apply_bn`` members.
        k1 .. k9: fractions of edges kept at each of the nine scales.
    """

    # Per-scale weights of the InfoNCE term (one per k1..k9) and its temperature.
    _INFONCE_ALPHAS = (1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2)
    _INFONCE_TAU = 0.1
    # Multiplier applied to the edge-importance loss.
    _EDGE_IMPORTANCE_WEIGHT = 10.0

    def __init__(self, gcn_model, mlp_model_1, classifier, k1, k2, k3, k4, k5, k6, k7, k8, k9):
        super(BaseModel, self).__init__()
        self.gcn_model = gcn_model  # separate GNN used only for node embeddings
        self.mlp_model_1 = mlp_model_1
        self.k1 = k1
        self.k2 = k2
        self.k3 = k3
        self.k4 = k4
        self.k5 = k5
        self.k6 = k6
        self.k7 = k7
        self.k8 = k8
        self.k9 = k9
        self.classifier = classifier
        # Edge-importance memory bank; not read or written elsewhere in this class.
        self.edge_importance_memory = {}

    def get_node_embeddings(self, feat, edge_index):
        """Return ``[num_nodes, embedding_dim]`` embeddings from ``gcn_model``."""
        node_embeddings = self.gcn_model.get_node_embeddings(feat, edge_index)
        return node_embeddings

    def get_classifier_features(self, feat, edge_index):
        """Return the classifier's second-layer node features for one graph.

        Rebuilds a symmetric dense adjacency from ``edge_index`` and replays
        ``conv_first`` and ``conv_block[0]`` of ``self.classifier`` (each
        followed by ``act`` and, when ``bn`` is set, ``apply_bn``).

        Returns:
            ``[num_nodes, hidden_dim]`` tensor.
        """
        x = feat.unsqueeze(0)  # [1, num_nodes, num_features]

        # Dense adjacency for this single graph, made symmetric.
        num_nodes = x.size(1)
        adj = torch.zeros((1, num_nodes, num_nodes), device=x.device)  # [1, num_nodes, num_nodes]
        adj[0, edge_index[0], edge_index[1]] = 1
        adj[0, edge_index[1], edge_index[0]] = 1

        # First layer.
        x, _ = self.classifier.conv_first(x, adj)
        x = self.classifier.act(x)
        if self.classifier.bn:
            x = self.classifier.apply_bn(x)

        # Second layer: these are the features returned.
        x, _ = self.classifier.conv_block[0](x, adj)
        x = self.classifier.act(x)
        if self.classifier.bn:
            x = self.classifier.apply_bn(x)

        x = x.squeeze(0)  # [num_nodes, hidden_dim]
        return x

    def forward(self, data, train=True):
        """Score edges, build the nine-scale subgraphs and compute the losses.

        Args:
            data: ``(feat, adj, label, pred)`` batch. ``feat[i]`` / ``adj[i]``
                are one graph's node features and dense adjacency; ``pred[i]``
                is the reference prediction compared by arg-max in the
                edge-importance loss. ``label`` is moved to ``device`` but not
                otherwise used.
            train: unused; kept for interface compatibility.

        Returns:
            A 12-tuple ``(L10, L11, out_full, out_1, ..., out_9)``: ``L10`` is
            the InfoNCE loss (summed over scales) averaged over the batch,
            ``L11`` is ``10 *`` the edge-importance loss averaged over the
            batch, ``out_full`` lists the classifier output of every full
            graph and ``out_s`` that of every scale-``s`` key subgraph.
        """
        feat, adj, label, pred = data
        feat = feat.to(device)
        adj = adj.to(device)
        label = label.to(device)
        pred = pred.to(device)

        ratios = (self.k1, self.k2, self.k3, self.k4, self.k5, self.k6, self.k7, self.k8, self.k9)
        num_scales = len(ratios)
        # classifier_outputs[0]: full graphs; classifier_outputs[s]: scale-s key subgraphs.
        classifier_outputs = [[] for _ in range(num_scales + 1)]
        batch_size = len(adj)
        # Per-graph losses are accumulated so every graph in the batch
        # contributes, not just the last one.
        infonce_total = torch.zeros((), device=device)
        edge_importance_total = torch.zeros((), device=device)

        for i in range(batch_size):
            edge_index = adj[i].nonzero(as_tuple=True)

            # Node embeddings from the separate GNN, then per-edge scores.
            node_vectors = self.get_node_embeddings(
                feat[i].to(device),
                torch.stack([edge_index[0].to(device), edge_index[1].to(device)]))
            edge_probs = self.compute_edge_probs(node_vectors, edge_index)

            # Subgraphs are built from the original node features.
            subgraphs = self.compute_subgraphs(
                feat[i], adj[i].to(device), edge_probs, edge_index, *ratios)
            key_feats = subgraphs[0:2 * num_scales:2]
            nonkey_feats = subgraphs[1:2 * num_scales:2]
            key_adjs = subgraphs[2 * num_scales:3 * num_scales]
            nonkey_adjs = subgraphs[3 * num_scales:4 * num_scales]
            topk_indices_list = subgraphs[4 * num_scales:5 * num_scales]

            pred_full, _ = self.classifier(feat[i].unsqueeze(0), adj[i].unsqueeze(0))
            # All key subgraphs first, then all complements (top-k edges removed);
            # this call order matters if the classifier is stochastic.
            pred_list = [self.classifier(g.unsqueeze(0), a.unsqueeze(0))[0]
                         for g, a in zip(key_feats, key_adjs)]
            pred_bar_list = [self.classifier(g.unsqueeze(0), a.unsqueeze(0))[0]
                             for g, a in zip(nonkey_feats, nonkey_adjs)]

            classifier_outputs[0].append(pred_full.squeeze(0))
            for s, pred_s in enumerate(pred_list, start=1):
                classifier_outputs[s].append(pred_s.squeeze(0))

            L10 = self._infonce_loss(i, pred_full.squeeze(0), pred_list, pred_bar_list,
                                     topk_indices_list, edge_probs)

            if edge_probs.numel() == 0:
                # No edges: nothing to push toward +/-1.
                edge_importance_loss = edge_probs.new_zeros(())
            else:
                edge_importance_loss = edge_importance_bce_loss(
                    edge_probs,
                    # Entry [0] is used to index edge_probs, so it must be the
                    # top-k edge positions, not node ids.
                    [(topk_idx,) for topk_idx in topk_indices_list],
                    pred_list,
                    pred[i]
                )
            L11 = self._EDGE_IMPORTANCE_WEIGHT * edge_importance_loss

            if torch.isnan(L10):
                print(f"[DEBUG] L10 is nan at sample {i}")
            if torch.isnan(L11):
                print(f"[DEBUG] L11 is nan at sample {i}")

            infonce_total = infonce_total + L10
            edge_importance_total = edge_importance_total + L11

        num_graphs = max(batch_size, 1)
        return (infonce_total / num_graphs, edge_importance_total / num_graphs, *classifier_outputs)

    def _infonce_loss(self, i, pred_full_s, pred_list, pred_bar_list, topk_indices_list, edge_probs):
        """Sum over scales of a two-candidate InfoNCE term for graph ``i``.

        At scale ``s`` the positive is the key-subgraph prediction weighted by
        the mean score of the key edges, the negative is the complement's
        prediction weighted by the mean score of the remaining edges; both are
        compared to the full-graph prediction by cosine similarity.
        """
        tau = self._INFONCE_TAU
        losses = []
        for s, (alpha, pos, neg, topk_idx) in enumerate(
                zip(self._INFONCE_ALPHAS, pred_list, pred_bar_list, topk_indices_list)):
            nonkey_mask = torch.ones_like(edge_probs, dtype=torch.bool)
            nonkey_mask[topk_idx] = False
            if topk_idx.numel() == 0 or not nonkey_mask.any():
                # No key/non-key split (graph without edges, or every edge
                # selected): a mean weight would be NaN and, once averaged
                # into the batch loss, poison the whole batch. Skip the term.
                continue

            pred_pos = pos.squeeze(0)
            pred_neg = neg.squeeze(0)
            w_pos = edge_probs[topk_idx].mean()
            w_neg = edge_probs[nonkey_mask].mean()
            sim_pos = F.cosine_similarity(pred_pos, pred_full_s, dim=0)
            sim_neg = F.cosine_similarity(pred_neg, pred_full_s, dim=0)
            if torch.isnan(w_pos):
                print(f"[DEBUG] w_pos is nan at sample {i}, s={s}")
            if torch.isnan(w_neg):
                print(f"[DEBUG] w_neg is nan at sample {i}, s={s}")
            if torch.isnan(sim_pos):
                print(f"[DEBUG] sim_pos is nan at sample {i}, s={s}")
            if torch.isnan(sim_neg):
                print(f"[DEBUG] sim_neg is nan at sample {i}, s={s}")

            num = torch.exp(alpha * w_pos * sim_pos / tau)
            denom = num + torch.exp(alpha * w_neg * sim_neg / tau)
            loss_infonce = -torch.log(num / (denom + 1e-8))
            if torch.isnan(loss_infonce):
                print(f"[DEBUG] loss_infonce is nan at sample {i}, s={s}")
                print(f"num: {num}, denom: {denom}")
            losses.append(loss_infonce)
        return sum(losses, edge_probs.new_zeros(()))

    def compute_edge_probs(self, node_vectors, edge_index):
        """Score every edge as ``tanh(mlp_model_1(h_src * h_dst))``.

        Edges whose endpoint ids are >= ``node_vectors.size(0)`` are not scored
        and get 0. Returns a 1-D tensor aligned with ``edge_index``, with
        values in [-1, 1] (tanh output, not probabilities).
        """
        num_nodes = node_vectors.size(0)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]

        valid_edges = (edge_src < num_nodes) & (edge_dst < num_nodes)
        if not valid_edges.all():
            # Drop edges that point outside node_vectors.
            valid_src = edge_src[valid_edges]
            valid_dst = edge_dst[valid_edges]
            if valid_src.size(0) == 0:
                # No valid edge at all: every score is 0.
                return torch.zeros(edge_src.size(0), device=device)
        else:
            valid_src = edge_src
            valid_dst = edge_dst

        start_nodes = node_vectors[valid_src.to(device)]  # [num_edges, hidden_dim]
        end_nodes = node_vectors[valid_dst.to(device)]    # [num_edges, hidden_dim]

        # Element-wise product keeps the feature dimension for the MLP.
        edge_scores = start_nodes * end_nodes  # [num_edges, hidden_dim]

        edge_probs = torch.tanh(self.mlp_model_1(edge_scores))  # [num_edges, 1]
        edge_probs = edge_probs.squeeze(-1)  # [num_edges]

        if not valid_edges.all():
            # Scatter the scores back to the original edge positions.
            full_edge_probs = torch.zeros(edge_src.size(0), device=device)
            full_edge_probs[valid_edges] = edge_probs
            return full_edge_probs
        else:
            return edge_probs

    def compute_subgraphs(self, original_feat, adj, edge_probs, edge_index, k1, k2, k3, k4, k5, k6, k7, k8, k9):
        """Build nine nested key subgraphs and their complements.

        At scale ``s`` the key edges are the top ``max(1, int(k_s * num_edges))``
        entries of ``edge_probs`` and the key nodes are their endpoints.

        Returns:
            A flat 45-tuple, in this order:
            ``G1_1, G1_2, ..., G9_1, G9_2``: ``original_feat`` with rows of
            non-key nodes zeroed (``_1``) / rows of key nodes zeroed (``_2``);
            ``adj_trimmed_1 .. adj_trimmed_9``: ``adj`` restricted to the key
            edges (mask symmetrised);
            ``adj_trimmed_1_2 .. adj_trimmed_9_2``: ``adj`` with those edges
            removed;
            ``topk_edge_indices_1 .. topk_edge_indices_9``: positions of the
            key edges in ``edge_index`` / ``edge_probs``.
        """
        sorted_indices = torch.argsort(edge_probs, descending=True)
        num_edges = sorted_indices.shape[0]
        src, dst = edge_index[0], edge_index[1]

        feature_pairs = []
        kept_adjs = []
        dropped_adjs = []
        topk_list = []
        for ratio in (k1, k2, k3, k4, k5, k6, k7, k8, k9):
            topk = sorted_indices[:max(1, int(ratio * num_edges))]
            topk_src = src[topk]
            topk_dst = dst[topk]

            # Node mask over the endpoints of the key edges.
            key_nodes = torch.unique(torch.cat((topk_src, topk_dst), dim=-1).flatten())
            node_mask = torch.zeros_like(original_feat).to(device)
            node_mask[key_nodes] = 1
            feature_pairs.append(original_feat * node_mask)
            feature_pairs.append(original_feat * (1 - node_mask))

            # Symmetric adjacency mask over the key edges.
            edge_mask = torch.zeros_like(adj).to(device)
            edge_mask[topk_src.to(device), topk_dst.to(device)] = 1
            edge_mask[topk_dst.to(device), topk_src.to(device)] = 1
            kept_adjs.append(adj * edge_mask)
            dropped_adjs.append(adj.clone().to(device) * (1 - edge_mask))
            topk_list.append(topk)

        return (*feature_pairs, *kept_adjs, *dropped_adjs, *topk_list)

    def compute_loss_1(self, feat, m, adj_trimmed_1, adj_trimmed_2, adj):
        """Contrastive loss over classifier features of randomly perturbed graphs.

        Two perturbations built from ``adj_trimmed_2`` form the positive pair
        and five built from ``adj_trimmed_1`` are negatives (see
        ``perturb_edges_1``). Each graph is embedded as the node-mean of
        ``get_classifier_features``. Returns the InfoNCE-style loss divided by
        (number of negatives + 1). ``m`` is unused.
        """
        positive_examples = [self.perturb_edges_1(adj_trimmed_2, adj) for _ in range(2)]
        negative_examples = [self.perturb_edges_1(adj_trimmed_1, adj) for _ in range(5)]

        Z_P_list = []
        for positive_example in positive_examples:
            positive_example_edge_index = positive_example.nonzero(as_tuple=True)
            edge_index = torch.stack([
                positive_example_edge_index[0].clone().detach(),
                positive_example_edge_index[1].clone().detach()
            ]).to(device)
            Z_P = self.get_classifier_features(feat.to(device), edge_index)
            Z_P = torch.mean(Z_P, dim=0)  # [hidden_dim]
            Z_P_list.append(Z_P)  # [hidden_dim]

        temperature = 0.1
        similarity_P = torch.nn.functional.cosine_similarity(Z_P_list[0], Z_P_list[1], dim=0) / temperature

        similarity_N = []
        for negative_example in negative_examples:
            negative_example_edge_index = negative_example.nonzero(as_tuple=True)
            edge_index = torch.stack([
                negative_example_edge_index[0].clone().detach(),
                negative_example_edge_index[1].clone().detach()
            ]).to(device)
            Z_N = self.get_classifier_features(feat.to(device), edge_index)
            Z_N = torch.mean(Z_N, dim=0)  # [hidden_dim]
            similarity_N.append(torch.nn.functional.cosine_similarity(Z_P_list[0], Z_N, dim=0) / temperature)

        similarity_N_exp = torch.exp(torch.stack(similarity_N))
        denominator = torch.sum(similarity_N_exp) + torch.exp(similarity_P) + 1e-8
        contrastive_loss = -torch.log(torch.exp(similarity_P) / denominator)
        return contrastive_loss / (len(negative_examples) + 1)

    def perturb_edges_1(self, adj_trimmed, adj):
        """Return a copy of ``adj`` with random edges of ``adj_trimmed`` removed.

        ``int(0.5 * adj_trimmed.sum())`` of ``adj_trimmed``'s nonzero entries
        are chosen at random and zeroed in both directions.
        """
        num_edges_to_remove = int(0.5 * adj_trimmed.sum().item())
        edge_indices = adj_trimmed.nonzero(as_tuple=True)
        indices_to_remove = torch.randperm(edge_indices[0].size(0))[:num_edges_to_remove]
        perturbed_adj_trimmed = adj.clone()
        perturbed_adj_trimmed[edge_indices[0][indices_to_remove], edge_indices[1][indices_to_remove]] = 0
        perturbed_adj_trimmed[edge_indices[1][indices_to_remove], edge_indices[0][indices_to_remove]] = 0
        return perturbed_adj_trimmed


class MyModel(nn.Module):
    """Thin wrapper that calls ``self.model`` and re-exposes its 12 outputs."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, train=True):
        """Run the wrapped model; its output must unpack into exactly 12 values."""
        (loss_a, loss_b,
         out_1, out_2, out_3, out_4, out_5,
         out_6, out_7, out_8, out_9, out_10) = self.model(x, train)
        return (loss_a, loss_b,
                out_1, out_2, out_3, out_4, out_5,
                out_6, out_7, out_8, out_9, out_10)


class MLPModel(torch.nn.Module):
    """Three-layer MLP: (Linear -> LayerNorm -> ReLU -> Dropout(0.2)) x 2 -> Linear.

    Maps ``[batch_size, input_size]`` to ``[batch_size, output_size]``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.2)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Apply both hidden blocks, then the output projection."""
        for linear, norm in ((self.fc1, self.ln1), (self.fc2, self.ln2)):
            x = self.dropout(F.relu(norm(linear(x))))
        return self.fc3(x)


def train(model, dataloader, optimizer, device):
    """Run one training epoch and report the average loss and per-scale accuracy.

    ``model(data, train=True)`` must return twelve values
    ``(L2, L5, pred1, ..., pred10)``: the contrastive loss, the
    edge-importance loss and ten lists of per-graph predictions (each list is
    stacked with ``torch.stack``). Accuracy compares ``argmax(dim=1)`` of each
    stacked prediction with ``data[3].argmax(dim=1)``; ``data[2].size(0)`` is
    taken as the number of samples in the batch.

    The optimized objective is ``L2 + L5``. At the end of the epoch the mean /
    max / min gradient norms per component (``gcn_model``, ``mlp_model_1``,
    ``mlp_model_2``, ``classifier``, matched by parameter name) and the loss
    averages are printed.

    Args:
        model: module returning the twelve outputs described above.
        dataloader: iterable of indexable batches ``data``.
        optimizer: optimizer over the trainable parameters of ``model``.
        device: kept for interface compatibility; targets are moved to the
            device of the predictions instead.

    Returns:
        ``(avg_loss, acc1, ..., acc10)`` where ``avg_loss`` is the per-batch
        mean of ``L2 + L5``. An empty dataloader yields all zeros.
    """
    def to_float(value):
        # Losses may be 0-dim tensors or plain Python numbers.
        return value.item() if hasattr(value, 'item') else value

    model.train()
    total_loss = total_L2 = total_L5 = 0.0
    # Gradient norms per component; names are matched in this order, first match wins.
    gradient_stats = {
        'gcn_model': [],
        'mlp_model_1': [],
        'mlp_model_2': [],
        'classifier': [],
    }
    correct = [0] * 10
    total_samples = 0
    num_batches = 0
    for data in dataloader:
        optimizer.zero_grad()
        L2, L5, *pred_lists = model(data, train=True)
        if len(pred_lists) != len(correct):
            raise ValueError(f"expected 12 model outputs, got {len(pred_lists) + 2}")
        preds = [torch.stack(p) for p in pred_lists]
        # Align the targets with the predictions' device (the batch may still be on CPU).
        target = data[3].to(preds[0].device).argmax(dim=1)
        for k, pred in enumerate(preds):
            correct[k] += (pred.argmax(dim=1) == target).sum().item()

        loss = L2 + L5
        loss.backward()
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            component = next((c for c in gradient_stats if c in name), None)
            if component is not None:
                gradient_stats[component].append(param.grad.norm().item())
        optimizer.step()

        l2_value, l5_value = to_float(L2), to_float(L5)
        total_L2 += l2_value
        total_L5 += l5_value
        # The reported total must match the optimized objective (L2 + L5).
        total_loss += l2_value + l5_value
        total_samples += data[2].size(0)
        num_batches += 1

    # Guard the averages against an empty dataloader.
    batches = max(num_batches, 1)
    samples = max(total_samples, 1)
    avg_L2 = total_L2 / batches
    avg_L5 = total_L5 / batches
    avg_loss = total_loss / batches

    print("\n=== 梯度统计 ===")
    for component, gradients in gradient_stats.items():
        if gradients:
            print(f"{component}:")
            print(f"  平均梯度: {sum(gradients) / len(gradients):.6f}")
            print(f"  最大梯度: {max(gradients):.6f}")
            print(f"  最小梯度: {min(gradients):.6f}")
    print("\n=== 损失统计 ===")
    print(f"对比损失: {avg_L2:.6f}")
    print(f"边重要性损失: {avg_L5:.6f}")
    print(f"总损失: {avg_loss:.6f}")
    return (avg_loss, *(c / samples for c in correct))


def test(model, dataloader, device):
    """Evaluate ``model`` without gradients and report loss and per-scale accuracy.

    ``model(data, train=False)`` must return twelve values
    ``(L2, L5, pred1, ..., pred10)``: two losses and ten lists of per-graph
    predictions (each list is stacked with ``torch.stack``). Accuracy compares
    ``argmax(dim=1)`` of each stacked prediction with ``data[3].argmax(dim=1)``;
    ``data[2].size(0)`` is taken as the number of samples in the batch.

    Args:
        model: module returning the twelve outputs described above.
        dataloader: iterable of indexable batches ``data``.
        device: kept for interface compatibility; targets are moved to the
            device of the predictions instead.

    Returns:
        ``(avg_loss, acc1, ..., acc10)`` where ``avg_loss`` is the per-batch
        mean of ``L2 + L5``. An empty dataloader yields all zeros.
    """
    def to_float(value):
        # Losses may be 0-dim tensors or plain Python numbers.
        return value.item() if hasattr(value, 'item') else value

    model.eval()
    total_loss = total_L2 = total_L5 = 0.0
    correct = [0] * 10
    total_samples = 0
    num_batches = 0
    with torch.no_grad():
        for data in dataloader:
            L2, L5, *pred_lists = model(data, train=False)
            if len(pred_lists) != len(correct):
                raise ValueError(f"expected 12 model outputs, got {len(pred_lists) + 2}")
            preds = [torch.stack(p) for p in pred_lists]
            # Align the targets with the predictions' device (the batch may still be on CPU).
            target = data[3].to(preds[0].device).argmax(dim=1)
            for k, pred in enumerate(preds):
                correct[k] += (pred.argmax(dim=1) == target).sum().item()

            l2_value, l5_value = to_float(L2), to_float(L5)
            total_L2 += l2_value
            total_L5 += l5_value
            # Total loss covers both terms, matching the training objective L2 + L5.
            total_loss += l2_value + l5_value
            total_samples += data[2].size(0)
            num_batches += 1

    # Guard the averages against an empty dataloader.
    batches = max(num_batches, 1)
    samples = max(total_samples, 1)
    avg_L2 = total_L2 / batches
    avg_L5 = total_L5 / batches
    avg_loss = total_loss / batches
    print(f"Average L2 Loss: {avg_L2}, Average L5 Loss: {avg_L5}")
    return (avg_loss, *(c / samples for c in correct))


# The name starts with "test", so pytest would otherwise collect this evaluation
# helper as a test case (and fail on its missing fixtures) when it is imported.
test.__test__ = False


# 主程序
if __name__ == '__main__':
    prog_args = parser.parse_args()
    
    # 获取数据集配置
    dataset_config = get_dataset_config(prog_args.dataset)
    print(f"Loading dataset: {prog_args.dataset}")
    print(f"Dataset config: {dataset_config}")

    # Load a configuration: auto-select CUDA if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 初始化记录最高准确率的变量
    best_accuracies = {
        'original': 0.0,
        '10%': 0.0,
        '20%': 0.0,
        '30%': 0.0,
        '40%': 0.0,
        '50%': 0.0,
        '60%': 0.0,
        '70%': 0.0,
        '80%': 0.0,
        '90%': 0.0
    }

    # Load a model checkpoint based on dataset
    ckpt_path = os.path.join('ckpt', dataset_config['ckpt_file'])
    print(f"Loading checkpoint from: {ckpt_path}")
    cg_ckpt = torch.load(ckpt_path, weights_only=False)
    cg_dict = cg_ckpt["cg"]  # get computation graph

    # 打印数据集维度信息
    if isinstance(cg_dict["feat"], list):
        input_dim = cg_dict["feat"][0].shape[2]  # 对于list类型，取第一个元素的特征维度
        print(f"\n=== 数据集维度信息 ===")
        print(f"特征维度 (input_dim): {input_dim}")
        print(f"类别数量 (num_classes): {cg_dict['pred'].shape[2]}")
        print(f"特征类型: list, 长度: {len(cg_dict['feat'])}")
        print(f"第一个特征形状: {cg_dict['feat'][0].shape}")
        print(f"预测形状 (pred shape): {cg_dict['pred'].shape}")
    else:
        input_dim = cg_dict["feat"].shape[2]
        num_classes = cg_dict["pred"].shape[2]
        print(f"\n=== 数据集维度信息 ===")
        print(f"特征维度 (input_dim): {input_dim}")
        print(f"类别数量 (num_classes): {num_classes}")
        print(f"特征形状 (feat shape): {cg_dict['feat'].shape}")
        print(f"预测形状 (pred shape): {cg_dict['pred'].shape}")
    
    num_classes = cg_dict["pred"].shape[2]

    # 重新划分训练集和测试集的索引
    train_idx = cg_dict["train_idx"][0:]
    test_idx = cg_dict["test_idx"][0:]

    # 划分为新的训练集和测试集
    train_idx_new = train_idx[:]
    test_idx_new = test_idx[:]

    # 创建数据集和dataloader
    train_dataset = GraphDataset(cg_dict, train_idx_new)
    test_dataset = GraphDataset(cg_dict, test_idx_new)

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # 根据数据集调整hidden_dim
    if prog_args.dataset in ['BA3', 'ba3_base_h64_o32', 'mnist_balanced_base_h64_o32_6layer']:
        # MNIST数据集使用64维隐藏层和32维嵌入
        hidden_dim = 64
        embedding_dim = 32
    else:
        # 其他数据集使用20维隐藏层
        hidden_dim = 20
        embedding_dim = 20

    # 根据数据集动态调整模型参数
    # MNIST数据集：特征维度5，类别数10，使用64维隐藏层和32维嵌入
    gcn_model = GCNModel(num_node_features=input_dim, hidden_channels=hidden_dim, embedding_dim=embedding_dim).to(device)
    mlp_model_1 = MLPModel(input_size=embedding_dim, hidden_size=16, output_size=1).to(device)  # 输入维度与GCN模型输出维度匹配
    
    # 根据数据集调整层数
    if prog_args.dataset == 'mnist_balanced_base_h64_o32_6layer':
        num_layers = 6  # MNIST数据集使用6层网络
    else:
        num_layers = 3  # 其他数据集使用3层网络
    
    classifier = models.GcnEncoderGraph(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        embedding_dim=embedding_dim,
        label_dim=num_classes,
        num_layers=num_layers,
        bn=False,
        args=argparse.Namespace(gpu=prog_args.gpu, bias=True, method=None)
    ).to(device)

    # 加载预训练权重并判断是否成功
    classifier_ckpt_path = os.path.join('ckpt', dataset_config['ckpt_file'])
    try:
        ckpt = torch.load(classifier_ckpt_path, map_location=device, weights_only=False)
        classifier.load_state_dict(ckpt["model_state"])
        print("成功加载 classifier 预训练权重！")
        # 打印部分参数均值做 sanity check
        for name, param in classifier.named_parameters():
            print(f"{name} mean: {param.data.mean().item():.6f}")
    except Exception as e:
        print(f"加载 classifier 预训练权重失败: {e}")
        print("classifier 将使用随机初始化参数！")

    # 冻结 classifier 的参数
    for param in classifier.parameters():
        param.requires_grad = False

    # 加载最终模型参数
    final_ckpt_name = f'model_checkpoint_final_{prog_args.dataset.lower()}.pth'
    try:
        final_ckpt = torch.load(final_ckpt_name, map_location=device)
        model_state = final_ckpt['model_state_dict']
        base_model = BaseModel(gcn_model, mlp_model_1, classifier, 
                              k1=0.1, k2=0.2, k3=0.3, k4=0.4, k5=0.5, 
                              k6=0.6, k7=0.7, k8=0.8, k9=0.9).to(device)
        model = MyModel(base_model).to(device)
        model.load_state_dict(model_state, strict=False)
        print(f"已成功导入 {final_ckpt_name} 检查点！")
    except Exception as e:
        print(f"无法加载 {final_ckpt_name} 检查点: {e}")
        # 这里要重新初始化模型
        base_model = BaseModel(gcn_model, mlp_model_1, classifier, 
                              k1=0.1, k2=0.2, k3=0.3, k4=0.4, k5=0.5, 
                              k6=0.6, k7=0.7, k8=0.8, k9=0.9).to(device)
        model = MyModel(base_model).to(device)
        print("模型将使用随机初始化的参数...")

    # 获取所有模型的参数
    all_parameters = list(model.parameters())
    
    # 打印GCN模型的参数是否可训练
    for name, param in gcn_model.named_parameters():
        print(name, param.requires_grad)
    
    # 打印MLP1的参数是否可训练
    for name, param in mlp_model_1.named_parameters():
        print(name, param.requires_grad)
    
    # 创建优化器，更新所有参数
    # MNIST数据集使用较小的学习率
    optimizer = torch.optim.AdamW(all_parameters, lr=0.001, weight_decay=0.01)
    
    # 使用学习率调度器
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8)

    # 设置起始epoch为0
    start_epoch = 0

    # === 新增：记录最佳模型相关变量 ===
    best_sum_acc = 0.0
    best_model_state = None
    best_accs = None
    best_epoch = 0

    # 在训练开始前进行初始测试
    print("\n=== 初始测试阶段 ===")
    print("测试模型在初始化状态下的性能...")
    initial_test_loss, initial_test_acc_pred1, initial_test_acc_pred2, initial_test_acc_pred3, initial_test_acc_pred4, initial_test_acc_pred5, initial_test_acc_pred6, initial_test_acc_pred7, initial_test_acc_pred8, initial_test_acc_pred9, initial_test_acc_pred10 = test(model, test_dataloader, device)
    
    # 计算原始图对真实标签的预测准确率
    print("\n=== 原始图对真实标签的预测准确率测试 ===")
    model.eval()
    correct_original_vs_true = 0
    total_samples = 0
    with torch.no_grad():
        for data in test_dataloader:
            feat, adj, label, pred = data
            feat = feat.to(device)
            adj = adj.to(device)
            label = label.to(device)
            pred = pred.to(device)
            
            for i in range(len(adj)):
                # 计算边索引
                edge_index = adj[i].nonzero(as_tuple=True)
                edge_index = torch.stack([edge_index[0].to(device), edge_index[1].to(device)])
                
                # 使用classifier进行预测
                pred_original, _ = classifier(feat[i].unsqueeze(0), adj[i].unsqueeze(0))
                pred_original = pred_original.squeeze(0)
                
                # 比较预测结果与真实标签（使用data[2]，即label）
                true_label = label[i]
                
                # 处理真实标签的维度 - 参考test_classifier_performance.py的逻辑
                if true_label.dim() > 1 and true_label.shape[0] > 1:
                    # 一热编码格式，转换为类别索引
                    true_class = true_label.argmax().item()
                else:
                    # 已经是类别索引格式
                    if true_label.dim() > 1:
                        true_label = true_label.squeeze()
                    # 对于标量（0维张量），直接转换为整数
                    if true_label.dim() == 0:
                        true_class = true_label.item()
                    else:
                        true_class = true_label.argmax().item()
                
                pred_class = pred_original.argmax().item()
                
                if pred_class == true_class:
                    correct_original_vs_true += 1
                total_samples += 1
    
    original_vs_true_accuracy = correct_original_vs_true / total_samples
    print(f"原始图对真实标签的预测准确率: {original_vs_true_accuracy:.4f}")
    
    # 同时计算原始图对模型预测标签的准确率（用于对比）
    correct_original_vs_pred = 0
    total_samples_pred = 0
    with torch.no_grad():
        for data in test_dataloader:
            feat, adj, label, pred = data
            feat = feat.to(device)
            adj = adj.to(device)
            label = label.to(device)
            pred = pred.to(device)
            
            for i in range(len(adj)):
                # 使用classifier进行预测
                pred_original, _ = classifier(feat[i].unsqueeze(0), adj[i].unsqueeze(0))
                pred_original = pred_original.squeeze(0)
                
                # 比较预测结果与模型预测标签（使用data[3]，即pred）
                model_pred = pred[i]
                if model_pred.dim() > 1:
                    model_pred = model_pred.squeeze()
                
                if pred_original.argmax() == model_pred.argmax():
                    correct_original_vs_pred += 1
                total_samples_pred += 1
    
    original_vs_pred_accuracy = correct_original_vs_pred / total_samples_pred
    print(f"原始图对模型预测标签的准确率: {original_vs_pred_accuracy:.4f}")
    
    print("\n初始测试结果:")
    print(f"初始测试损失: {initial_test_loss}")
    print(f"原始图初始测试准确率: {initial_test_acc_pred1}")
    print(f"10%子图初始测试准确率: {initial_test_acc_pred2}")
    print(f"20%子图初始测试准确率: {initial_test_acc_pred3}")
    print(f"30%子图初始测试准确率: {initial_test_acc_pred4}")
    print(f"40%子图初始测试准确率: {initial_test_acc_pred5}")
    print(f"50%子图初始测试准确率: {initial_test_acc_pred6}")
    print(f"60%子图初始测试准确率: {initial_test_acc_pred7}")
    print(f"70%子图初始测试准确率: {initial_test_acc_pred8}")
    print(f"80%子图初始测试准确率: {initial_test_acc_pred9}")
    print(f"90%子图初始测试准确率: {initial_test_acc_pred10}")
    print("\n=== 开始训练阶段 ===")

    # MNIST数据集训练300轮
    for epoch in range(start_epoch, start_epoch + 300):
        time1 = time.time()

        with torch.no_grad():
            # 输出GCN模型的参数
            gcn_model_params_info = [f"{name}: {param.mean().item():.6f}" for name, param in
                                   gcn_model.named_parameters()]
            print(f"Start Epoch {epoch} - GCN: " + ", ".join(gcn_model_params_info))
            
            # 输出MLP模型的参数
            mlp_model_1_params_info = [f"{name}: {param.mean().item():.6f}" for name, param in
                                     mlp_model_1.named_parameters()]
            print(f"Start Epoch {epoch} - MLP1: " + ", ".join(mlp_model_1_params_info))

        # Train phase
        train_loss, train_acc_pred1, train_acc_pred2, train_acc_pred3, train_acc_pred4, train_acc_pred5, train_acc_pred6, train_acc_pred7, train_acc_pred8, train_acc_pred9, train_acc_pred10 = train(model, train_dataloader, optimizer, device)
        print(f"Epoch: {epoch}, Train Loss: {train_loss}")

        # Test phase
        test_loss, test_acc_pred1, test_acc_pred2, test_acc_pred3, test_acc_pred4, test_acc_pred5, test_acc_pred6, test_acc_pred7, test_acc_pred8, test_acc_pred9, test_acc_pred10 = test(model, test_dataloader, device)

        print(f"Test Loss: {test_loss}")
        # 输出训练集和测试集准确率
        print(f"原始图预测准确率: {train_acc_pred1}")
        print(f"10%子图预测准确率: {train_acc_pred2}")
        print(f"20%子图预测准确率: {train_acc_pred3}")
        print(f"30%子图预测准确率: {train_acc_pred4}")
        print(f"40%子图预测准确率: {train_acc_pred5}")
        print(f"50%子图预测准确率: {train_acc_pred6}")
        print(f"60%子图预测准确率: {train_acc_pred7}")
        print(f"70%子图预测准确率: {train_acc_pred8}")
        print(f"80%子图预测准确率: {train_acc_pred9}")
        print(f"90%子图预测准确率: {train_acc_pred10}")

        print(f"原始图测试准确率: {test_acc_pred1}")
        print(f"10%子图测试准确率: {test_acc_pred2}")
        print(f"20%子图测试准确率: {test_acc_pred3}")
        print(f"30%子图测试准确率: {test_acc_pred4}")
        print(f"40%子图测试准确率: {test_acc_pred5}")
        print(f"50%子图测试准确率: {test_acc_pred6}")
        print(f"60%子图测试准确率: {test_acc_pred7}")
        print(f"70%子图测试准确率: {test_acc_pred8}")
        print(f"80%子图测试准确率: {test_acc_pred9}")
        print(f"90%子图测试准确率: {test_acc_pred10}")

        # === 新增：判断并保存最佳模型 ===
        current_accs = [
            test_acc_pred1, test_acc_pred2, test_acc_pred3, test_acc_pred4, test_acc_pred5,
            test_acc_pred6, test_acc_pred7, test_acc_pred8, test_acc_pred9, test_acc_pred10
        ]
        current_sum_acc = sum(current_accs)
        if current_sum_acc > best_sum_acc:
            best_sum_acc = current_sum_acc
            best_model_state = model.state_dict()
            best_accs = current_accs.copy()
            best_epoch = epoch
            best_model_save_path = f"best_model_{prog_args.dataset.lower()}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
            }, best_model_save_path)
            print(f"Best model updated at epoch {epoch} (sum acc: {best_sum_acc:.4f}) and saved to {best_model_save_path}")

        time2 = time.time()
        print(f"Epoch time: {time2 - time1}")

        # 更新学习率
        # scheduler.step()
        
        # 输出当前学习率
        # print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")
        
        # 定期保存模型检查点
        if epoch % 10 == 0 or epoch == start_epoch + 30:
            model_save_path = f"model_checkpoint_epoch{epoch}_{prog_args.dataset.lower()}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # 'scheduler_state_dict': scheduler.state_dict(),
            }, model_save_path)
            
            print(f"Model checkpoint saved to {model_save_path}")

    # 训练结束，保存最终模型
    final_model_save_path = f"model_checkpoint_final_{prog_args.dataset.lower()}.pth"
    torch.save({
        'epoch': start_epoch + 299,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # 'scheduler_state_dict': scheduler.state_dict(),
    }, final_model_save_path)

    print(f"Final model saved to {final_model_save_path}")

    # === 新增：输出最佳模型信息 ===
    print(f"\n=== 最佳模型信息 ({prog_args.dataset}) ===")
    print(f"最佳模型出现在 epoch: {best_epoch}")
    print(f"各尺度下的最佳准确率：")
    for i, acc in enumerate(best_accs):
        if i == 0:
            print(f"原始图: {acc:.4f}")
        else:
            print(f"{i*10}%子图: {acc:.4f}")
    print(f"准确率之和: {best_sum_acc:.4f}")
    best_model_save_path = f"best_model_{prog_args.dataset.lower()}.pth"
    print(f"最佳模型已保存为 {best_model_save_path}")

