import os
import numpy as np
import torch
import dgl
import torch.optim as optim
import torch.nn.functional as F
from model import *
from utils import *
import json
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, average_precision_score, confusion_matrix

# 导入混合精度训练所需的库
from torch.cuda.amp import autocast, GradScaler

import warnings
warnings.filterwarnings('ignore')
def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up schedule from https://arxiv.org/abs/1610.02242.

    Args:
        current: Current step (e.g. epoch index); clipped to
            ``[0, rampup_length]`` before use.
        rampup_length: Number of steps over which the value ramps up.

    Returns:
        float: ``exp(-5 * (1 - current / rampup_length) ** 2)``, or ``1.0``
        when ``rampup_length`` is 0.
    """
    # A zero-length ramp means "already fully ramped"; also avoids 0-division.
    if rampup_length == 0:
        return 1.0

    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))

def get_current_mu(epoch, args):
    """Return the weight ``mu`` to use at ``epoch``.

    Args:
        epoch: Current epoch index.
        args: Object exposing ``mu``, ``mu_rampup`` and
            ``consistency_rampup`` attributes.

    Returns:
        ``args.mu`` unchanged when ``args.mu_rampup`` is falsy; otherwise
        ``args.mu`` scaled by ``sigmoid_rampup(epoch, args.consistency_rampup)``.

    Side effects:
        When ramp-up is enabled and ``args.consistency_rampup`` is ``None``,
        it is set to 500 on ``args`` (the value persists for later calls).
    """
    if not args.mu_rampup:
        return args.mu

    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    if args.consistency_rampup is None:
        args.consistency_rampup = 500
    ramp = sigmoid_rampup(epoch, args.consistency_rampup)
    return args.mu * ramp

def initialize_centroids(features, k):
    """Pick ``k`` initial cluster centres with a k-means++-style seeding.

    The first centre is a uniformly random row of ``features``. Every further
    centre is sampled with probability proportional to each row's distance to
    its nearest already-chosen centre.

    NOTE(review): canonical k-means++ samples proportionally to the *squared*
    distance; this implementation uses the plain distance and is left as-is
    so seeded runs keep their behaviour.

    Args:
        features: 2-D tensor; rows are candidate points (``size(0)`` rows,
            ``size(1)`` columns).
        k: Number of centres to pick.

    Returns:
        Tensor of shape ``[k, features.size(1)]`` on the same device and with
        the same dtype as ``features``.
    """
    num_nodes = features.size(0)
    # Match the dtype of `features` so torch.cdist never sees mixed dtypes.
    centroids = torch.zeros(
        k, features.size(1), device=features.device, dtype=features.dtype
    )

    # First centre: uniformly random row.
    first_id = torch.randint(num_nodes, (1,)).item()
    centroids[0] = features[first_id]

    # Remaining centres: distance-weighted sampling.
    for i in range(1, k):
        # Distance from each row to its closest centre chosen so far.
        distances = torch.cdist(features, centroids[:i]).min(dim=1)[0]
        total = distances.sum()
        if total > 0:
            next_id = torch.multinomial(distances / total, 1).item()
        else:
            # Every point coincides with a chosen centre (or the sum is not
            # a positive number): the weights would be NaN and multinomial
            # would raise, so fall back to a uniform draw.
            next_id = torch.randint(num_nodes, (1,)).item()
        centroids[i] = features[next_id]

    return centroids

def check_convergence(centroids, prev_centroids, tol=1e-4):
    """Tell whether the centroids moved less than ``tol`` since the last step.

    Args:
        centroids: Current centroid tensor.
        prev_centroids: Centroid tensor from the previous iteration (same
            shape as ``centroids``).
        tol: Threshold on the norm of the difference.

    Returns:
        A boolean tensor: ``torch.norm(centroids - prev_centroids) < tol``.
    """
    shift = centroids - prev_centroids
    return torch.norm(shift) < tol

def robust_node_clustering(features, k=2, temperature=0.1, max_iterations=10, labeled_features=None, labeled_classes=None):
    """Soft-cluster node features around ``k`` centroids.

    Centroids are obtained in one of two ways, both under ``torch.no_grad()``:

    * If both ``labeled_features`` and ``labeled_classes`` are given, centroid
      ``i`` is the mean feature of the labelled samples whose class is ``i``
      (or a random unit vector when class ``i`` has no sample). All centroids
      are then L2-normalised and no iterative refinement is run.
    * Otherwise centroids are seeded with ``initialize_centroids`` and refined
      for at most ``max_iterations`` rounds of Gumbel-softmax soft assignment
      followed by a weighted-mean update, stopping early once
      ``check_convergence`` reports a shift below ``1e-4``.

    The returned assignments are computed outside ``no_grad`` from
    ``features`` (cast to float32) and the final centroids.

    Args:
        features: Node features, ``[num_nodes, feature_dim]``.
        k: Number of clusters (default 2).
        temperature: Used both to scale ``-distance`` into logits and as the
            Gumbel-softmax ``tau``.
        max_iterations: Upper bound on refinement rounds (unlabelled path).
        labeled_features: Optional features of labelled samples,
            ``[num_labeled, feature_dim]``.
        labeled_classes: Optional class id per labelled sample,
            ``[num_labeled]``.

    Returns:
        tuple: ``(original_cluster_assignments, view1_cluster_assignments,
        view2_cluster_assignments, centroids)``. The two "view" entries are
        the very same tensor object as ``original_cluster_assignments`` (no
        per-view assignment is computed here); assignments have shape
        ``[num_nodes, k]`` and centroids ``[k, feature_dim]``.
    """
    feature_dim = features.size(1)
    device = features.device

    # Work in float32 so features and centroids always share a dtype.
    features = features.to(torch.float32)
    use_labeled_init = labeled_features is not None and labeled_classes is not None

    # Centroid estimation never needs gradients.
    with torch.no_grad():
        if use_labeled_init:
            labeled_features = labeled_features.to(torch.float32)
            centroids = torch.zeros(k, feature_dim, device=device, dtype=torch.float32)

            for class_id in range(k):
                class_indices = torch.where(labeled_classes == class_id)[0]
                if len(class_indices) > 0:
                    # Class prototype = mean of its labelled samples.
                    centroids[class_id] = labeled_features[class_indices].mean(dim=0)
                else:
                    # No labelled sample for this class: random unit vector.
                    random_center = torch.randn(feature_dim, device=device, dtype=torch.float32)
                    centroids[class_id] = F.normalize(random_center, p=2, dim=0)

            # Give every centroid the same (unit) norm; epsilon avoids 0-division.
            norms = torch.norm(centroids, dim=1, keepdim=True)
            centroids = centroids / (norms + 1e-10)
        else:
            # Unsupervised path: k-means++-style seeding, then soft k-means.
            centroids = initialize_centroids(features, k)
            prev_centroids = centroids.clone()

            for _ in range(max_iterations):
                distances = torch.cdist(features, centroids)  # [num_nodes, k]
                logits = -distances / temperature
                cluster_assignments = F.gumbel_softmax(logits, tau=temperature, hard=False)

                new_centroids = torch.zeros_like(centroids)
                for j in range(k):
                    weights = cluster_assignments[:, j].unsqueeze(1)  # [num_nodes, 1]
                    weight_sum = weights.sum()
                    if weight_sum > 0:
                        new_centroids[j] = (features * weights).sum(0) / weight_sum
                    else:
                        # Empty cluster: keep the previous centre.
                        new_centroids[j] = centroids[j].clone()
                centroids = new_centroids

                if check_convergence(centroids, prev_centroids, tol=1e-4):
                    break
                prev_centroids = centroids.clone()

    # Final soft assignment, computed with autograd enabled so gradients can
    # flow back into `features`.
    distances_original = torch.cdist(features, centroids)  # [num_nodes, k]
    logits_original = -distances_original / temperature
    original_cluster_assignments = F.gumbel_softmax(logits_original, tau=temperature, hard=False)

    # Both "views" currently reuse the original assignment tensor.
    view1_cluster_assignments = original_cluster_assignments
    view2_cluster_assignments = original_cluster_assignments

    return original_cluster_assignments, view1_cluster_assignments, view2_cluster_assignments, centroids

def compute_clustering_loss(features, cluster_assignments, centroids, epsilon=1e-6):
    """Clustering loss combining intra-cluster, inter-cluster and compactness terms.

    ``features`` and ``centroids`` are L2-normalised row-wise first. Each node
    is hard-assigned to ``argmax(cluster_assignments, dim=1)``. The intra
    term only looks at cluster ids 0 ("negative") and 1 ("positive"), i.e. it
    is written for the two-cluster case.

    Terms:
        * intra: mean distance of cluster-1 nodes to centroid 1 plus mean
          distance of cluster-0 nodes to centroid 0 (a term is 0 when its
          cluster is empty).
        * inter: negative mean pairwise centroid distance.
        * joint_reg: mean squared distance of each node to its assigned
          centroid divided by ``mean pairwise centroid distance + epsilon``.

    Args:
        features: Node embeddings, ``[N, D]``.
        cluster_assignments: Per-node cluster scores, ``[N, K]``.
        centroids: Cluster centres, ``[K, D]``.
        epsilon: Stabiliser for the ``joint_reg`` denominator.

    Returns:
        tuple: ``(total_loss, num_pos, num_neg)`` where
        ``total_loss = 0.5 * intra + 0.5 * inter + 0.1 * joint_reg`` and the
        counts are Python ints for clusters 1 and 0 respectively.
    """
    features = F.normalize(features, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)

    # Hard assignment is only used for indexing, so no graph is needed.
    with torch.no_grad():
        hard_assignments = torch.argmax(cluster_assignments, dim=1)
        pos_indices = torch.nonzero(hard_assignments == 1).squeeze(-1)
        neg_indices = torch.nonzero(hard_assignments == 0).squeeze(-1)
        num_pos = pos_indices.numel()
        num_neg = neg_indices.numel()

    distances = torch.cdist(features, centroids)  # [N, K]
    if num_pos > 0:
        intra_positive_loss = torch.mean(distances[pos_indices, 1])
    else:
        intra_positive_loss = torch.tensor(0.0, device=features.device)
    if num_neg > 0:
        intra_negative_loss = torch.mean(distances[neg_indices, 0])
    else:
        intra_negative_loss = torch.tensor(0.0, device=features.device)
    intra_loss = intra_positive_loss + intra_negative_loss

    # Push centroids apart: larger mean pairwise distance -> lower loss.
    mean_centroid_dist = torch.mean(torch.pdist(centroids))
    inter_loss = -mean_centroid_dist

    # Compactness relative to centroid separation.
    expanded_centroids = torch.index_select(centroids, 0, hard_assignments)
    compactness = torch.mean(torch.sum((features - expanded_centroids) ** 2, dim=1))
    joint_reg = compactness / (mean_centroid_dist + epsilon)

    total_loss = 0.5 * intra_loss + 0.5 * inter_loss + 0.1 * joint_reg
    return total_loss, num_pos, num_neg

def get_augmented_view(features, drop_rate=0.2):
    """Build an augmented copy of ``features`` by zeroing random columns.

    One uniform random number is drawn per feature column; a column is kept
    when its draw is strictly greater than ``drop_rate`` and zeroed for every
    row otherwise. The input tensor is not modified.

    Args:
        features: 2-D feature tensor (columns are feature dimensions).
        drop_rate: Threshold on the per-column uniform draw.

    Returns:
        A new tensor shaped like ``features`` with the dropped columns set to 0.
    """
    num_columns = features.size(1)
    keep_columns = torch.rand(num_columns, device=features.device) > drop_rate
    drop_columns = ~keep_columns

    augmented = features.clone()
    augmented[:, drop_columns] = 0
    return augmented

def nt_xent_loss(z_i, z_j, temperature=0.05):
    """NT-Xent (InfoNCE) contrastive loss between two aligned batches.

    Row ``r`` of ``z_i`` and row ``r`` of ``z_j`` form a positive pair; every
    other row of the concatenated ``[z_i; z_j]`` batch is a negative. For each
    of the ``2 * B`` rows the loss is
    ``-log(exp(s_pos / T) / sum_{other != self} exp(s_other / T))`` with
    cosine similarity ``s``; the mean over rows is returned.

    The value is computed in log-sum-exp form, so small temperatures no
    longer overflow ``exp`` into ``inf / inf = NaN``.

    Args:
        z_i: Embeddings of the first view, ``[B, D]``.
        z_j: Embeddings of the second view, ``[B, D]`` (same ``B`` as ``z_i``).
        temperature: Softmax temperature ``T``.

    Returns:
        Scalar tensor with the mean contrastive loss.
    """
    z_i = F.normalize(z_i, dim=-1)
    z_j = F.normalize(z_j, dim=-1)
    representations = torch.cat([z_i, z_j], dim=0)

    batch_size = z_i.size(0)
    num_rows = 2 * batch_size

    # Pairwise cosine similarity, [2B, 2B], scaled into logits.
    sim_matrix = F.cosine_similarity(
        representations.unsqueeze(1), representations.unsqueeze(0), dim=-1
    )
    logits = sim_matrix / temperature

    # A sample is never its own positive or negative: remove the diagonal
    # from the softmax by sending it to -inf.
    self_mask = torch.eye(num_rows, dtype=torch.bool, device=z_i.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The single positive of row r sits B rows away (wrapping around).
    row_ids = torch.arange(num_rows, device=z_i.device)
    positive_ids = (row_ids + batch_size) % num_rows
    positive_logits = logits[row_ids, positive_ids]

    # -log(exp(pos) / sum(exp(all))) == logsumexp(all) - pos
    loss = (torch.logsumexp(logits, dim=1) - positive_logits).mean()
    return loss
def generate_contrastive_pairs(batch_nodes, labels):
    """Sample one positive and one negative partner for each batch node.

    For every node in ``batch_nodes``:

    * positive: a random *other* node carrying the same label (skipped when
      the node is the only member of its class);
    * negative: a random node carrying a different label (skipped when no
      such node exists).

    Partners are drawn with ``np.random.choice`` from all indices of
    ``labels``, not only from ``batch_nodes``.

    Args:
        batch_nodes: Node indices of the current batch (list, array or
            tensor on any device).
        labels: Label of every node, indexable by node id (list, array or
            tensor on any device). Non-tensor inputs are converted with
            ``np.asarray`` so that the element-wise label comparison also
            works for plain Python lists.

    Returns:
        tuple: ``(positive_pairs, negative_pairs)``, two lists of
        ``(node, partner)`` tuples.
    """
    # Bring labels to a NumPy array regardless of container / device.
    if isinstance(labels, torch.Tensor):
        labels_np = labels.detach().cpu().numpy()
    else:
        labels_np = np.asarray(labels)

    # Tensor indices become plain Python ints.
    if isinstance(batch_nodes, torch.Tensor):
        batch_nodes = batch_nodes.detach().cpu().tolist()

    positive_pairs = []
    negative_pairs = []

    for node in batch_nodes:
        node_label = labels_np[node]

        # Positive: any other node of the same class.
        same_class_nodes = np.where(labels_np == node_label)[0]
        if len(same_class_nodes) > 1:
            candidates = same_class_nodes[same_class_nodes != node]
            pos_pair = np.random.choice(candidates, 1)[0]
            positive_pairs.append((node, pos_pair))

        # Negative: any node of a different class.
        diff_class_nodes = np.where(labels_np != node_label)[0]
        if len(diff_class_nodes) > 0:
            neg_pair = np.random.choice(diff_class_nodes, 1)[0]
            negative_pairs.append((node, neg_pair))

    return positive_pairs, negative_pairs



class GradientAwareFocalLoss(nn.Module):
    """Focal cross-entropy re-weighted by class statistics and gradient norm.

    Every sample's cross-entropy is multiplied by the product of

    * a focal factor ``(1 - p_t) ** gamma_focal``,
    * a class weight derived from an exponential moving average of how often
      each class appears among the hardest ``k_percent`` % samples, and
    * a gradient factor ``(||g|| + 1e-8) ** gamma_grad`` where ``g`` is the
      gradient of an auxiliary cross-entropy w.r.t. the (detached) inputs,

    with the weights renormalised to mean 1 before the final mean.

    The running statistics live in the ``class_counts`` / ``class_weights``
    buffers and are updated on every ``forward`` call.
    """

    def __init__(self, num_classes, k_percent=10, gamma_focal=2.0, gamma_ga=0.5, gamma_grad=1.0, use_softmax=True):
        """
        Args:
            num_classes: Number of classes (size of the running buffers).
            k_percent: Percentage of highest-loss samples used to update the
                class counts.
            gamma_focal: Exponent of the focal factor.
            gamma_ga: Class weights are ``(1 / count) ** (1 - gamma_ga)``.
            gamma_grad: Exponent applied to the per-sample gradient norm.
            use_softmax: If True ``forward`` applies softmax over dim 1 to
                its inputs; otherwise inputs are used as probabilities.
        """
        super(GradientAwareFocalLoss, self).__init__()
        self.num_classes = num_classes
        self.k_percent = k_percent
        self.gamma_focal = gamma_focal
        self.gamma_ga = gamma_ga
        self.gamma_grad = gamma_grad  # strength of the gradient weighting
        self.use_softmax = use_softmax
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_weights', torch.ones(num_classes))

    def forward(self, inputs, targets):
        """Compute the weighted loss.

        Args:
            inputs: Scores of shape ``(B, C, ...)`` with classes on dim 1.
            targets: Integer class ids with one entry per sample
                (``B * prod(...)`` elements in total).

        Returns:
            Scalar loss tensor. Gradients flow through the cross-entropy and
            the focal factor; class and gradient weights carry no graph.
        """
        B, C = inputs.shape[:2]
        N = inputs.shape[2:].numel() * B  # total number of samples
        # Move the class axis last before flattening to (N, C).
        channel_last = (0, *range(2, inputs.dim()), 1)

        # 1. Probabilities and base cross-entropy.
        probs = F.softmax(inputs, dim=1) if self.use_softmax else inputs
        probs = probs.permute(*channel_last).contiguous().view(-1, C)
        targets = targets.reshape(-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -torch.log(pt + 1e-8)

        # 2. Per-sample gradient of an auxiliary loss w.r.t. a detached copy
        #    of the inputs. enable_grad() keeps this working when the caller
        #    runs forward() under torch.no_grad() (e.g. validation).
        # NOTE(review): F.cross_entropy applies log-softmax itself, so with
        # use_softmax=True the auxiliary loss sees a double softmax. Kept
        # as-is to preserve existing training behaviour - confirm intent.
        with torch.enable_grad():
            inputs_grad = inputs.detach().requires_grad_(True)
            probs_grad = F.softmax(inputs_grad, dim=1) if self.use_softmax else inputs_grad
            # Same channel-last flattening as above so rows line up with
            # `targets` for inputs with more than two dimensions.
            probs_grad = probs_grad.permute(*channel_last).contiguous().view(-1, C)
            loss_grad = F.cross_entropy(probs_grad, targets, reduction='none')
            gradients = torch.autograd.grad(
                outputs=loss_grad,
                inputs=inputs_grad,
                grad_outputs=torch.ones_like(loss_grad),
                create_graph=False,
            )[0]  # same shape as inputs: (B, C, ...)

        # 3. Gradient magnitude (L2 norm over the class axis).
        gradients = gradients.permute(*channel_last).contiguous().view(-1, C)
        grad_magnitude = gradients.norm(p=2, dim=1)  # (N,)
        grad_weight = (grad_magnitude + 1e-8) ** self.gamma_grad  # never exactly zero

        # 4. Dynamic class balancing from the hardest k% samples.
        num_topk = max(1, int(self.k_percent / 100 * N))
        _, topk_indices = torch.topk(ce_loss, num_topk, sorted=False)
        topk_targets = targets[topk_indices]
        current_counts = torch.bincount(topk_targets, minlength=self.num_classes).float()
        self.class_counts = 0.9 * self.class_counts + 0.1 * current_counts
        effective_counts = self.class_counts + 1e-8
        self.class_weights = (1.0 / effective_counts) ** (1.0 - self.gamma_ga)
        self.class_weights = self.class_weights / self.class_weights.sum() * C

        # 5. Combine the three weights: focal * class * gradient.
        focal_weight = (1 - pt) ** self.gamma_focal
        class_weight = self.class_weights[targets]

        # step 1: class-aware difficulty, normalised to mean 1
        difficulty_weight = class_weight * grad_weight
        difficulty_weight = difficulty_weight / (difficulty_weight.mean())

        # step 2: sample-level hardness (focal), normalised to mean 1
        final_weight = focal_weight * difficulty_weight
        final_weight = final_weight / (final_weight.mean())

        # 6. Final loss.
        loss = (final_weight * ce_loss).mean()
        return loss
    

# 添加LPL相关的类
def get_step(split: int, classes_num: int, pgd_nums: int, classes_freq: list):
    """Compute a per-class step count from class frequencies.

    Class ``i`` gets ``pgd_nums + max(0, round(ratio_i * 0.1 * pgd_nums - 1))``
    steps, where ``ratio_i`` is ``classes_freq[i] / classes_freq[0]`` for
    ``i < split`` and ``classes_freq[i] / classes_freq[-1]`` otherwise.

    Args:
        split: Index separating the two frequency references.
        classes_num: Number of classes to produce steps for.
        pgd_nums: Base number of steps added to every class.
        classes_freq: Per-class frequencies, indexable by class id.

    Returns:
        list: One step count per class, each at least ``pgd_nums``.
    """
    step_size = pgd_nums*0.1
    head_freq = classes_freq[0]
    tail_freq = classes_freq[-1]

    class_step = []
    for class_id in range(classes_num):
        reference = head_freq if class_id < split else tail_freq
        extra = round((classes_freq[class_id] / reference) * step_size - 1)
        # Negative extras are clamped to zero before adding the base count.
        class_step.append(pgd_nums + (0 if extra < 0 else extra))
    return class_step

class LPLLoss_advanced(nn.Module):
    """Adaptive logit-perturbation loss (LPL variant).

    Logits are perturbed for a per-sample number of steps with a per-sample
    strength, then scored with cross-entropy. Steps and strengths are derived
    from running class frequencies, running per-class loss and each sample's
    gradient norm.

    NOTE: minority/majority handling uses ``majority = 1 - minority`` and
    therefore assumes two classes.
    """

    def __init__(self, num_classes=2, pgd_nums=50, alpha=0.1, min_class_factor=3.0):
        """
        Args:
            num_classes: Number of classes.
            pgd_nums: Base number of perturbation steps.
            alpha: Base perturbation strength.
            min_class_factor: The minority class strength is raised to at
                least this multiple of the majority class strength.
        """
        super().__init__()
        self.num_classes = num_classes
        self.pgd_nums = pgd_nums
        self.alpha = alpha
        self.min_class_factor = min_class_factor
        self.criterion = nn.CrossEntropyLoss()

        # Running (momentum-averaged) class counts and per-class loss.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_grad_mags', torch.zeros(num_classes))
        self.momentum = 0.9  # momentum of both running averages

    def update_statistics(self, logit, y):
        """Update the running class counts and per-class mean cross-entropy.

        Args:
            logit: Logits, ``[N, num_classes]``.
            y: Integer class ids, ``[N]``.
        """
        with torch.no_grad():
            batch_counts = torch.bincount(y, minlength=self.num_classes).float()
            self.class_counts = self.momentum * self.class_counts + (1 - self.momentum) * batch_counts

            # Mean cross-entropy per class serves as the gradient-magnitude
            # estimate; classes absent from the batch contribute 0.
            grad_mags = torch.zeros(self.num_classes, device=logit.device)
            for c in range(self.num_classes):
                class_mask = (y == c)
                if torch.sum(class_mask) > 0:
                    ce_loss = F.cross_entropy(logit[class_mask], y[class_mask], reduction='none')
                    grad_mags[c] = ce_loss.mean()

            self.class_grad_mags = self.momentum * self.class_grad_mags + (1 - self.momentum) * grad_mags

    def compute_adaptive_params(self, logit, y):
        """Derive the per-sample step count and perturbation strength.

        Also calls ``update_statistics`` for this batch.

        Args:
            logit: Logits, ``[N, num_classes]``.
            y: Integer class ids, ``[N]``.

        Returns:
            tuple: ``(sample_steps, sample_alphas)``, a long tensor and a
            float tensor, each with one entry per sample.
        """
        with torch.no_grad():
            self.update_statistics(logit, y)

            # Class distribution from the running counts.
            total_samples = torch.sum(self.class_counts)
            class_ratios = self.class_counts / (total_samples + 1e-8)

            minority_idx = torch.argmin(class_ratios).item()
            majority_idx = 1 - minority_idx  # binary-classification assumption

            # Imbalance ratio, clamped to [1, 10].
            imbalance_ratio = class_ratios[majority_idx] / (class_ratios[minority_idx] + 1e-8)
            imbalance_factor = torch.clamp(imbalance_ratio, 1.0, 10.0)
            minority_boost = min(5.0, imbalance_factor.item() ** 0.5)

            # Classes with a larger running loss get a stronger perturbation.
            grad_scale = F.softmax(self.class_grad_mags, dim=0)

            class_steps = torch.zeros(self.num_classes, device=logit.device, dtype=torch.long)
            class_alphas = torch.zeros(self.num_classes, device=logit.device, dtype=torch.float)

            max_steps = int(self.pgd_nums * 2.0)
            min_steps = max(1, int(self.pgd_nums * 0.5))

            for c in range(self.num_classes):
                # Rarer class -> larger factor -> more steps.
                freq_factor = torch.sqrt(1.0 / (class_ratios[c] + 1e-8))
                steps = min_steps + int((max_steps - min_steps) * freq_factor / (freq_factor + 1.0))
                class_steps[c] = steps

                alpha_base = self.alpha * (1.0 + grad_scale[c].item() * 2.0)
                # The minority class gets an additional boost.
                class_alphas[c] = alpha_base * minority_boost if c == minority_idx else alpha_base

            # Minority steps are at least 1.5x the majority steps.
            if class_steps[minority_idx] < class_steps[majority_idx] * 1.5:
                class_steps[minority_idx] = int(class_steps[majority_idx] * 1.5)

            # Minority strength is at least min_class_factor x the majority strength.
            if class_alphas[minority_idx] < class_alphas[majority_idx] * self.min_class_factor:
                class_alphas[minority_idx] = class_alphas[majority_idx] * self.min_class_factor

            # Broadcast the class-level parameters to the samples.
            sample_steps = torch.zeros_like(y, dtype=torch.long)
            sample_alphas = torch.zeros_like(y, dtype=torch.float)
            for c in range(self.num_classes):
                class_mask = (y == c)
                sample_steps[class_mask] = class_steps[c]
                sample_alphas[class_mask] = class_alphas[c]

            # Sample-level adjustment from the gradient norm; autograd is
            # re-enabled locally because we are inside no_grad().
            with torch.enable_grad():
                logit_grad = logit.detach().clone().requires_grad_(True)
                loss = F.cross_entropy(logit_grad, y, reduction='none')
                grads = torch.autograd.grad(
                    outputs=loss.sum(),
                    inputs=logit_grad,
                    create_graph=False,
                    retain_graph=False
                )[0]

                sample_grad_norms = torch.norm(grads, p=2, dim=1)
                sample_difficulties = F.softmax(sample_grad_norms, dim=0)
                relative_difficulty = sample_difficulties / (torch.max(sample_difficulties) + 1e-8)

                # Strength scale in (0.8, 1.5], step scale in (1.0, 1.5].
                difficulty_scales = 0.8 + 0.7 * relative_difficulty
                sample_alphas = sample_alphas * difficulty_scales

                steps_difficulty_scales = 1.0 + 0.5 * relative_difficulty
                sample_steps = (sample_steps.float() * steps_difficulty_scales).long()

            return sample_steps, sample_alphas

    def compute_adv_sign(self, logit, y, sample_alphas):
        """Compute one perturbation step for every sample.

        Args:
            logit: Current logits, ``[N, num_classes]``.
            y: Integer class ids, ``[N]``.
            sample_alphas: Per-sample perturbation strength, ``[N]``.

        Returns:
            tuple: ``(adv_logit, sub)`` where ``adv_logit`` is the
            ``[N, num_classes]`` increment to add to ``logit`` and ``sub`` is
            the per-class difference between the mean true-class probability
            (over classes present in ``y``) and each class's own value.
        """
        with torch.no_grad():
            logit_softmax = F.softmax(logit, dim=-1)
            y_onehot = F.one_hot(y, num_classes=self.num_classes)

            # Mean softmax output per class.
            sum_class_logit = torch.matmul(
                y_onehot.permute(1, 0)*1.0, logit_softmax)
            sum_class_num = torch.sum(y_onehot, dim=0)

            # Remember which classes occur in the batch BEFORE patching the
            # zero counts; computing the mask afterwards would make it
            # all-True and let absent classes pull the threshold down.
            class_present = sum_class_num > 0
            # Absent classes get a dummy count of 100 to avoid 0-division.
            sum_class_num = torch.where(
                class_present, sum_class_num, torch.full_like(sum_class_num, 100))
            mean_class_logit = torch.div(sum_class_logit, sum_class_num.reshape(-1, 1))

            # Perturbation direction per class.
            grad = mean_class_logit - torch.eye(self.num_classes, device=logit.device)
            grad = torch.div(grad, torch.norm(grad, p=2, dim=0).reshape(-1, 1) + 1e-8)

            # Sign: classes whose mean true-class probability is below the
            # average over present classes get +1, those above get -1.
            mean_class_p = torch.diag(mean_class_logit)
            mean_class_thr = torch.mean(mean_class_p[class_present])
            sub = mean_class_thr - mean_class_p
            sign = sub.sign()

            # Scale by the per-sample strength.
            alphas_expanded = sample_alphas.unsqueeze(1).expand(-1, self.num_classes)
            adv_logit = torch.index_select(grad, 0, y) * alphas_expanded * sign[y].unsqueeze(1)

            return adv_logit, sub

    def compute_eta(self, logit, y):
        """Compute the total perturbation ``eta`` for every sample.

        The perturbation is iterated for ``max(sample_steps)`` steps; each
        sample keeps the state reached after its own number of steps.

        Args:
            logit: Logits, ``[N, num_classes]``.
            y: Integer class ids, ``[N]``.

        Returns:
            tuple: ``(eta, sample_steps, sample_alphas)`` with ``eta`` shaped
            like ``logit``.
        """
        with torch.no_grad():
            sample_steps, sample_alphas = self.compute_adaptive_params(logit, y)

            max_steps = torch.max(sample_steps).item()

            # `logit_news` starts as the step-0 state; a sample's row is
            # overwritten exactly once, when the loop reaches its step count.
            # This avoids storing every intermediate state and selecting
            # rows one by one in Python.
            logit_news = logit.clone()
            current_logit = logit.clone()
            for step in range(1, max_steps + 1):
                adv_logit, _ = self.compute_adv_sign(current_logit, y, sample_alphas)
                current_logit = current_logit + adv_logit
                reached = (sample_steps == step).unsqueeze(1)
                logit_news = torch.where(
                    reached, current_logit.to(logit_news.dtype), logit_news)

            eta = logit_news - logit

            return eta, sample_steps, sample_alphas

    def forward(self, models_or_logits, x=None, y=None, is_logits=False):
        """Compute the loss on adaptively perturbed logits.

        Args:
            models_or_logits: Logits when ``is_logits`` is True, otherwise a
                callable that maps ``x`` to logits.
            x: Model input (only used when ``is_logits`` is False).
            y: Integer class ids, ``[N]``. Required.
            is_logits: Whether ``models_or_logits`` already holds logits.

        Returns:
            tuple: ``(loss_adv, logit, logit_news, sample_steps,
            sample_alphas)``.

        Raises:
            ValueError: If ``y`` is None.
        """
        if y is None:
            raise ValueError("LPLLoss_advanced.forward requires target labels `y`")

        logit = models_or_logits if is_logits else models_or_logits(x)

        # eta carries no graph; gradients flow through `logit` only.
        eta, sample_steps, sample_alphas = self.compute_eta(logit, y)
        logit_news = logit + eta

        loss_adv = self.criterion(logit_news, y)

        return loss_adv, logit, logit_news, sample_steps, sample_alphas
def calculate_g_mean(y_true, y_pred):
    """Geometric mean of the per-class recalls.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        The ``n``-th root of the product of the per-class recalls taken from
        the rows of ``confusion_matrix(y_true, y_pred)``, where ``n`` is the
        number of rows. A class whose row sums to zero counts as recall 0.
    """
    cm = confusion_matrix(y_true, y_pred)

    per_class_recall = []
    for class_idx, row in enumerate(cm):
        support = row.sum()  # TP + FN for this class
        hits = row[class_idx]  # TP
        per_class_recall.append(hits / support if support != 0 else 0)

    return np.prod(per_class_recall) ** (1 / len(per_class_recall))
def evaluation(logits, y_eval):
    """Compute binary-classification metrics from log-probabilities.

    Samples whose label is neither 0 nor 1 (e.g. unlabelled nodes marked 2)
    are dropped first. The positive-class probability is ``exp(logits)[:, 1]``
    and predictions use a 0.5 threshold on it.

    Args:
        logits: Model outputs, ``[N, 2]``, as a tensor (any device) or a
            NumPy array; ``exp`` is applied, so log-probabilities are
            expected.
        y_eval: Ground-truth labels, ``[N]`` (sequence, array or tensor on
            any device).

    Returns:
        tuple: ``(auc, ap, f1_macro, g_mean, acc_label0, acc_label1,
        acc_overall)``. All zeros when no valid sample remains.
    """
    # Arrays are wrapped on the CPU; results are moved to the CPU below
    # anyway, so no specific accelerator is required.
    if isinstance(logits, np.ndarray):
        logits = torch.as_tensor(logits)

    # Labels as a NumPy array (tensors may live on an accelerator).
    if isinstance(y_eval, torch.Tensor):
        y_eval_np = y_eval.detach().cpu().numpy()
    else:
        y_eval_np = np.array(y_eval)

    # Keep only samples labelled 0 or 1.
    valid_mask = (y_eval_np == 0) | (y_eval_np == 1)
    if not np.all(valid_mask):
        y_eval_np = y_eval_np[valid_mask]
        logits = logits[torch.as_tensor(valid_mask, device=logits.device)]

    if len(y_eval_np) == 0:
        print("警告: 过滤后没有有效样本进行评估")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

    # Positive-class probability and thresholded prediction.
    positive_class_probs = torch.exp(logits).cpu().detach()[:, 1].numpy()
    label_prob = (positive_class_probs >= 0.5).astype(int)

    auc_score = roc_auc_score(y_eval_np, positive_class_probs)
    acc_overall = accuracy_score(y_eval_np, label_prob)

    # Per-label accuracy: share of each label's samples predicted correctly
    # (0.0 when the label does not occur).
    is_label0 = y_eval_np == 0
    is_label1 = y_eval_np == 1
    num_label0 = np.sum(is_label0)
    num_label1 = np.sum(is_label1)
    acc_label0 = np.sum(is_label0 & (label_prob == 0)) / num_label0 if num_label0 > 0 else 0.0
    acc_label1 = np.sum(is_label1 & (label_prob == 1)) / num_label1 if num_label1 > 0 else 0.0

    ap_score = average_precision_score(y_eval_np, positive_class_probs)
    f1_score_val = f1_score(y_eval_np, label_prob, average='macro')
    g_mean = calculate_g_mean(y_eval_np, label_prob)

    return auc_score, ap_score, f1_score_val, g_mean, acc_label0, acc_label1, acc_overall


if __name__ == '__main__':
    args = parse_args()
    setup_seed(72)  # 设置随机种子为72
    device = torch.device(args.cuda)
    args.device = device
    dataset_path = args.data_path+args.dataset+'.dgl'
    model_path = args.result_path+args.dataset+'_model.pt'
    log_path = args.result_path+args.dataset+'_log.json'
    results = {'F1-macro':[],'AUC':[],'G-Mean':[],'AP':[],'ACC1':[],'ACC0':[]}
    if not os.path.exists(args.result_path):
        os.makedirs(args.result_path)
    
    # 初始化混合精度训练所需的scaler
    scaler = GradScaler()
    
    '''
    # load dataset and normalize feature
    '''
    dataset = dgl.load_graphs(dataset_path)[0][0]
    features = dataset.ndata['feature'].numpy()
    features = normalize(features)
    dataset.ndata['feature'] = torch.from_numpy(features).float()
    
    # 打印节点标签的统计信息，检查标签是否合法
    all_labels = dataset.ndata['label']
    unique_labels = torch.unique(all_labels)
    print(f"数据集中的唯一标签值: {unique_labels}")
    for label in unique_labels:
        label_count = (all_labels == label).sum().item()
        print(f"标签 {label} 的数量: {label_count}")

    # 初始化GradientAwareFocalLoss损失函数
    gradient_aware_focal = GradientAwareFocalLoss(num_classes=2,
                                                  k_percent=10,
                                                  gamma_focal=2,
                                                  gamma_ga=0.5,
                                                  gamma_grad=2,
                                                  use_softmax=True).to(device)
    # 初始化自适应LPL损失函数
    adaptive_lpl_loss = LPLLoss_advanced(
        num_classes=2,
        pgd_nums=20,
        alpha=0.01,
        min_class_factor=3
    ).to(device)

    # 获取训练集中的正负样本索引
    train_mask = dataset.ndata['train_mask'].bool()
    train_labels = dataset.ndata['label'][train_mask]
    positive_indices = (train_labels == 1).nonzero().flatten()
    negative_indices = (train_labels == 0).nonzero().flatten()
    
    print(f"训练集中正样本数量: {len(positive_indices)}")
    print(f"训练集中负样本数量: {len(negative_indices)}")
    
    # 随机选择部分样本作为有标签样本（每类选择args.num_labeled个）
    num_labeled = getattr(args, 'num_labeled', 1)  # 默认每类1个
    print(f"每类选择 {num_labeled} 个有标签样本")
    
    # 确保不超过可用样本数量
    num_pos_labeled = min(num_labeled, len(positive_indices))
    num_neg_labeled = min(num_labeled, len(negative_indices))
    torch.manual_seed(1)
    if num_pos_labeled > 0:
        selected_positive = positive_indices[torch.randperm(len(positive_indices))[:num_pos_labeled]]
    else:
        selected_positive = torch.tensor([], dtype=torch.long)

    if num_neg_labeled > 0:
        selected_negative = negative_indices[torch.randperm(len(negative_indices))[:num_neg_labeled]]
    else:
        selected_negative = torch.tensor([], dtype=torch.long)

    labeled_nodes_tensor = torch.cat([selected_positive, selected_negative])
    print(f"有标签样本总数: {len(labeled_nodes_tensor)}")
    
    # 获取未被选中的样本索引
    remaining_positive = positive_indices[~torch.isin(positive_indices, selected_positive)]
    remaining_negative = negative_indices[~torch.isin(negative_indices, selected_negative)]
    unlabeled_nodes_tensor = torch.cat([remaining_positive, remaining_negative])
    print(f"无标签样本总数: {len(unlabeled_nodes_tensor)}")
    
    # 将数据集移至指定设备
    dataset = dataset.to(device)
    labeled_nodes_tensor = labeled_nodes_tensor.to(device)
    unlabeled_nodes_tensor = unlabeled_nodes_tensor.to(device)
    
    '''
    # train model
    '''
    print('Start training model...')
    model = H2FDetector(args, dataset)
    model = model.to(device)
    optimizer = optim.Adam(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    early_stop = EarlyStop(args.early_stop)
    
    valid_logs = []
    test_logs = [] # 用于记录每个epoch的测试结果
    if not hasattr(args, 'mu'):
        args.mu = 1.5  # 默认值
    if not hasattr(args, 'mu_rampup'):
        args.mu_rampup = True  # 默认启用rampup
    if not hasattr(args, 'consistency_rampup'):
        args.consistency_rampup = None  # 默认使用总epoch数
    if not hasattr(args, 'overwrite_viz'):
        args.overwrite_viz = True  # 默认覆盖之前的图片
    if not hasattr(args, 'clustering_temperature'):
        args.clustering_temperature = 0.8  # 默认聚类温度参数
    # 初始化最佳测试指标
    best_test_auc = 0.0
    best_test_f1_macro = 0.0
    best_test_gmean = 0.0
    best_test_ap = 0.0
    best_test_acc1 = 0.0
    best_test_acc0 = 0.0
    best_test_epoch = -1
    use_clustering_pseudo_labels = True
    fixed_cluster_epochs = 10
    use_original_pseudo_labels = True
    
    # 创建数据增强视图
    dateset_aug1 = dataset
    dateset_aug2 = dataset
    features1 = dateset_aug1.ndata['feature']
    features2 = dateset_aug2.ndata['feature']
    feat_aug1 = get_augmented_view(features1, drop_rate=0.2)
    feat_aug2 = get_augmented_view(features2, drop_rate=0.3)
    dateset_aug1.ndata['feature'] = feat_aug1
    dateset_aug2.ndata['feature'] = feat_aug2
    
    for e in range(args.epoch):
        model.train()
        epoch = e
        current_mu = get_current_mu(e, args)
        # 使用混合精度训练
        with autocast():
            try:

                original_out, original_h = model(dataset, return_hidden=True)

                

                out1, h1 = model(dateset_aug1, return_hidden=True)

                

                out2, h2 = model(dateset_aug2, return_hidden=True)

                
                # 计算有标签样本的分类损失
                batch_label = dataset.ndata['label'][labeled_nodes_tensor]
                # 检查标签的有效性
                if torch.any(batch_label >= 2):
                    print("警告: 发现标签值 >= 2，这可能导致nll_loss出错")
                    # 确保标签值小于类别数
                    batch_label = torch.clamp(batch_label, 0, 1)
                

                classification_loss_1 = F.nll_loss(out1[labeled_nodes_tensor], batch_label)
                classification_loss_2 = F.nll_loss(out2[labeled_nodes_tensor], batch_label)
                classification_loss = classification_loss_1 + classification_loss_2

                
                # 计算一致性损失 (使用所有样本)
                consistency_loss = F.mse_loss(h1, h2)

                
                positive_pairs, negative_pairs = generate_contrastive_pairs(labeled_nodes_tensor.tolist(), dataset.ndata['label'])
                # 确保索引是long类型的tensor
                if len(positive_pairs) > 0:
                    z_i_1 = h1[torch.tensor([p[0] for p in positive_pairs], dtype=torch.long, device=device)]
                    z_j_1 = h1[torch.tensor([p[1] for p in positive_pairs], dtype=torch.long, device=device)]
                    z_i_2 = h2[torch.tensor([p[0] for p in positive_pairs], dtype=torch.long, device=device)]
                    z_j_2 = h2[torch.tensor([p[1] for p in positive_pairs], dtype=torch.long, device=device)]
                    
                    contrastive_loss_1 = nt_xent_loss(z_i_1, z_j_1)
                    contrastive_loss_2 = nt_xent_loss(z_i_2, z_j_2)
                    contrastive_loss = contrastive_loss_1 + contrastive_loss_2
                else:
                    contrastive_loss = torch.tensor(0.0, device=device)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    print(f"遇到CUDA内存不足错误: {e}")
                    print("尝试减少batch size或简化模型")
                    # 设置默认损失值，防止程序崩溃
                    classification_loss = torch.tensor(1.0, device=device, requires_grad=True)
                    consistency_loss = torch.tensor(0.0, device=device, requires_grad=True)
                    contrastive_loss = torch.tensor(0.0, device=device, requires_grad=True)
                    # 跳过后续的聚类和伪标签生成
                    clustering_loss = torch.tensor(0.0, device=device, requires_grad=True)
                    pseudo_label_loss = torch.tensor(0.0, device=device, requires_grad=True)
                    adap_lpl_loss = torch.tensor(0.0, device=device, requires_grad=True)
                    model_loss = torch.tensor(1.0, device=device, requires_grad=True)
                    
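                    # Jump straight to the loss computation and optimizer step.
                    # NOTE(review): inside this handler `e` is bound to the caught
                    # RuntimeError (it shadows the epoch index of the for-loop), so the
                    # get_current_mu(e, args) call and the "Epoch {e}" message below
                    # receive the exception object rather than the epoch number.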
                    current_mu = get_current_mu(e, args)
                    loss = classification_loss
                    
                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    
                    print(f"Epoch {e}使用替代损失训练完成")
                    continue  # 跳过本轮剩余部分
                else:
                    raise e  # 如果不是内存问题，则重新抛出异常

        # autocast之外的代码 - 处理聚类和伪标签
        h_orig_unlabeled = original_h[unlabeled_nodes_tensor].to(torch.float32)  # 确保类型一致
        h1_unlabeled = h1[unlabeled_nodes_tensor].to(torch.float32)  # 确保类型一致
        h2_unlabeled = h2[unlabeled_nodes_tensor].to(torch.float32)  # 确保类型一致
                
        # 获取有标签样本的特征，用于固定聚类中心
        labeled_features_orig = original_h[labeled_nodes_tensor].to(torch.float32)  # 确保类型一致
        labeled_classes = batch_label
        
        try:
            # 在前10个epoch中固定聚类中心为已有的正常和欺诈样本
            if epoch < fixed_cluster_epochs and use_clustering_pseudo_labels:
                        
                # 使用有标签数据初始化原始图的聚类中心，同时处理三个视图
                cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                    h_orig_unlabeled,  # 原始图特征
                    k=2, 
                    temperature=0.8,
                    max_iterations=10,
                    labeled_features=labeled_features_orig,  # 传入有标签样本特征
                    labeled_classes=labeled_classes       # 传入有标签样本类别
                )
                 
            elif use_clustering_pseudo_labels:
                    
                # 10个epoch后，让聚类算法自由寻找更好的聚类中心，同时处理三个视图
                cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                    h_orig_unlabeled,  # 原始图特征
                    k=2, 
                    temperature=0.8,
                    max_iterations=10
                )

            # 创建合并的特征和分配 - 确保类型一致
            all_features = torch.cat([h_orig_unlabeled, h1_unlabeled, h2_unlabeled], dim=0)
            all_assignments = torch.cat([cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2], dim=0)
                    
            #计算统一的聚类损失
            clustering_loss, num_pos_all, num_neg_all = compute_clustering_loss(
                all_features, 
                all_assignments, 
                centroids_orig  # 使用原始图的聚类中心
            )
        except RuntimeError as e:
            print(f"聚类过程中出错: {e}")
            print("尝试使用统一的数据类型重新计算")
            # 再次尝试聚类，但确保所有张量都是相同的数据类型
            try:
                # 强制转换所有特征为相同类型
                all_unlabeled_features = [
                    h_orig_unlabeled.to(torch.float32),
                    h1_unlabeled.to(torch.float32),
                    h2_unlabeled.to(torch.float32)
                ]
                
                # 在前10个epoch中固定聚类中心为已有的正常和欺诈样本
                if epoch < fixed_cluster_epochs and use_clustering_pseudo_labels:
                    cluster_assignments_orig, _, _, centroids_orig = robust_node_clustering(
                        all_unlabeled_features[0],  # 原始图特征
                        k=2, 
                        temperature=0.8,
                        max_iterations=10,
                        labeled_features=labeled_features_orig.to(torch.float32),  # 传入有标签样本特征
                        labeled_classes=labeled_classes       # 传入有标签样本类别
                    )
                else:
                    cluster_assignments_orig, _, _, centroids_orig = robust_node_clustering(
                        all_unlabeled_features[0],  # 原始图特征
                        k=2, 
                        temperature=0.8,
                        max_iterations=10
                    )
                
                # 使用相同的聚类中心分配其他视图
                distances_1 = torch.cdist(all_unlabeled_features[1], centroids_orig)
                distances_2 = torch.cdist(all_unlabeled_features[2], centroids_orig)
                
                logits_1 = -distances_1 / 0.8
                logits_2 = -distances_2 / 0.8
                
                cluster_assignments_view1 = F.softmax(logits_1, dim=1)
                cluster_assignments_view2 = F.softmax(logits_2, dim=1)
                
                # 创建合并的特征和分配 - 确保类型一致
                all_features = torch.cat(all_unlabeled_features, dim=0)
                all_assignments = torch.cat([
                    cluster_assignments_orig, 
                    cluster_assignments_view1, 
                    cluster_assignments_view2
                ], dim=0)
                
                #计算统一的聚类损失
                clustering_loss, num_pos_all, num_neg_all = compute_clustering_loss(
                    all_features, 
                    all_assignments, 
                    centroids_orig  # 使用原始图的聚类中心
                )
            except Exception as e2:
                print(f"重试聚类仍然失败: {e2}")
                # 如果再次失败，则使用零聚类损失
                clustering_loss = torch.tensor(0.0, device=device, requires_grad=True)
        
        # 伪标签生成逻辑：取消置信度筛选，所有无标签样本均参与
        # 聚类结果：多数类为负样本(0)，少数类为正样本(1)
        # 概率对齐：确保用于融合的聚类概率列顺序为 [P(负), P(正)]
        with torch.no_grad():
            final_pseudo_labels_for_batch_unlabeled = torch.tensor([], dtype=torch.long, device=device)
                    
            # 确定伪标签来源和计算逻辑
            if use_original_pseudo_labels and use_clustering_pseudo_labels:
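                # Case 1: both model output and clustering results are enabled.
                # NOTE(review): the cluster probabilities are column-aligned below but
                # never fused; combined_probs_unlabeled is just the model probabilities.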
                orig_logits_unlabeled = original_out[unlabeled_nodes_tensor]
                orig_probs_unlabeled = F.softmax(orig_logits_unlabeled, dim=1) # 模型输出概率 [P(负), P(正)]

                # cluster_assignments_orig 是 [P(属聚类0), P(属聚类1)]
                # 确定聚类0和聚类1哪个是多数 (负)，哪个是少数 (正)
                temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1) # 初步判断样本属于哪个聚类
                count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()
                        
                aligned_cluster_probs = cluster_assignments_orig.clone()
                if count_c0 < count_c1: 
                    # 聚类0是少数类 (应映射为正样本, 标签1)
                    # 聚类1是多数类 (应映射为负样本, 标签0)
                    # 调整列使 aligned_cluster_probs 为 [P(负=聚类1), P(正=聚类0)]
                    aligned_cluster_probs[:, 0] = cluster_assignments_orig[:, 1] # 负样本概率 = 原聚类1概率
                    aligned_cluster_probs[:, 1] = cluster_assignments_orig[:, 0] # 正样本概率 = 原聚类0概率
                # else: count_c0 >= count_c1
                    # 聚类0是多数类 (负样本, 标签0)
                    # 聚类1是少数类 (正样本, 标签1)
                    # aligned_cluster_probs 无需换列，已经是 [P(负=聚类0), P(正=聚类1)]
                        
                combined_probs_unlabeled = orig_probs_unlabeled
                final_pseudo_labels_for_batch_unlabeled = torch.argmax(combined_probs_unlabeled, dim=1)

            elif use_clustering_pseudo_labels:
                # 场景2: 仅使用聚类结果
                temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1) # 初步判断样本属于哪个聚类
                count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()

                if count_c0 >= count_c1: 
                    # 聚类0是多数 (负样本=0), 聚类1是少数 (正样本=1)
                    # 伪标签与 temp_cluster_hard_labels 一致 (聚类0的为0, 聚类1的为1)
                    final_pseudo_labels_for_batch_unlabeled = temp_cluster_hard_labels
                else: 
                    # 聚类1是多数 (负样本=0), 聚类0是少数 (正样本=1)
                    # 伪标签与 temp_cluster_hard_labels 相反 (聚类0的为1, 聚类1的为0)
                    final_pseudo_labels_for_batch_unlabeled = 1 - temp_cluster_hard_labels
                    
            elif use_original_pseudo_labels:
                # 场景3: 仅使用模型输出
                orig_logits_unlabeled = original_out[unlabeled_nodes_tensor]
                final_pseudo_labels_for_batch_unlabeled = torch.argmax(orig_logits_unlabeled, dim=1)
                    
            # 为当前批次中所有无标签样本分配伪标签
            if final_pseudo_labels_for_batch_unlabeled.numel() > 0:
                # consistent_high_conf_indices 是相对于 unlabeled_nodes_tensor 的索引
                consistent_high_conf_indices = torch.arange(final_pseudo_labels_for_batch_unlabeled.size(0), device=device)
                consistent_pseudo_labels = final_pseudo_labels_for_batch_unlabeled
            else:
                consistent_high_conf_indices = torch.tensor([], dtype=torch.long, device=device)
                consistent_pseudo_labels = torch.tensor([], dtype=torch.long, device=device) # Ensure consistent_pseudo_labels is defined
                    
            # Define num_consistent_high_conf based on the number of pseudo-labels generated
            num_consistent_high_conf = consistent_pseudo_labels.numel()
        
        # 获取对应的预测
        if num_consistent_high_conf > 0:  # 确保有伪标签
            pseudo_logits_1 = out1[unlabeled_nodes_tensor][consistent_high_conf_indices]
            pseudo_logits_2 = out2[unlabeled_nodes_tensor][consistent_high_conf_indices]
            # 视图1
            pseudo_label_loss_1 = gradient_aware_focal(
                pseudo_logits_1, 
                consistent_pseudo_labels
            )
                                        
            # 视图2
            pseudo_label_loss_2 = gradient_aware_focal(
                pseudo_logits_2, 
                consistent_pseudo_labels
            )
            pseudo_label_loss = (pseudo_label_loss_1 + pseudo_label_loss_2) / 2

            # adap_lpl_loss_1, _, _, steps_1, alphas_1 = adaptive_lpl_loss(pseudo_logits_1, None, consistent_pseudo_labels, is_logits=True)
                                        
                            
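            # adap_lpl_loss_2, _, _, steps_2, alphas_2 = adaptive_lpl_loss(pseudo_logits_2, None, consistent_pseudo_labels, is_logits=True)
            # NOTE(review): with the call above commented out, adap_lpl_loss_2 is not
            # assigned anywhere in the visible part of this function, yet it is read
            # when adap_lpl_loss is averaged below - confirm it is defined elsewhere.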
            adap_lpl_loss_1 = torch.tensor(0.0, device=device)
            
            adap_lpl_loss = (adap_lpl_loss_1 + adap_lpl_loss_2) / 2
        else:
            pseudo_label_loss = torch.tensor(0.0, device=device)
            adap_lpl_loss = torch.tensor(0.0, device=device)
            

        # 使用异常处理捕获模型损失计算中的OOM错误
        try:
            model_loss = model.loss(dataset)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print("模型损失计算OOM，使用替代损失")
                # 使用替代损失避免OOM
                model_loss = classification_loss * 0.5
            else:
                raise e
        

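        # Total loss.
        # NOTE(review): only model_loss is assigned to `loss` here; the classification,
        # consistency, contrastive, clustering and pseudo-label losses computed
        # earlier in this iteration are not added to it.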
        loss = model_loss
        
    
        # 使用GradScaler更新梯度
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        print(f"Epoch {e}训练完成，损失: {loss.item():.4f}")
        
        # 验证和测试部分
        try:
            with torch.no_grad():
                '''
                # valid
                '''
                model.eval()
                valid_mask = dataset.ndata['valid_mask'].bool()
                valid_labels_all = dataset.ndata['label'] # 使用全部标签来获取对应的valid_labels
                valid_labels = valid_labels_all[valid_mask].cpu().numpy()
                
                with autocast():
                    valid_logits = model(dataset)[valid_mask]
                    
                # log_softmax输出直接取argmax就可以，因为log是单调函数
                valid_preds = valid_logits.argmax(1).cpu().numpy()
                # 如果需要概率值，需要正确转换log概率
                valid_probs = torch.exp(valid_logits)  # 转换为概率分布
                
                # 计算验证集上的指标
                auc_score, ap_score, f1_score_val, g_mean, acc_label0, acc_label1, acc_overall = evaluation(
                    valid_logits, 
                    valid_labels
                )
                
                valid_log = {
                    'epoch': e,
                    'loss': loss.item(),
                    'auc': auc_score,
                    'ap': ap_score,
                    'f1': f1_score_val,
                    'g_mean': g_mean,
                    'acc1': acc_label1,
                    'acc0': acc_label0,
                    'acc_overall': acc_overall
                }
                valid_logs.append(valid_log)
                
                if args.log:
                    print(f'{e}: Valid Loss:{loss.item():.4f}, AUC:{auc_score:.4f}, AP:{ap_score:.4f}, F1:{f1_score_val:.4f}, G-Mean:{g_mean:.4f}, ACC1:{acc_label1:.4f}, ACC0:{acc_label0:.4f}, ACC_overall:{acc_overall:.4f}')

                '''
                # test model at each epoch
                '''
                test_mask = dataset.ndata['test_mask'].bool()
                test_labels_all = dataset.ndata['label'] # 使用全部标签
                test_labels_current = test_labels_all[test_mask].cpu().numpy() # 获取当前测试集标签
                
                with autocast():
                    test_logits = model(dataset)[test_mask]
                
                auc_score, ap_score, f1_score_val, g_mean, acc_label0, acc_label1, acc_overall = evaluation(
                    test_logits, 
                    test_labels_current
                )
                
                test_log = {
                    'epoch': e,
                    'auc': auc_score,
                    'ap': ap_score,
                    'f1': f1_score_val,
                    'g_mean': g_mean,
                    'acc1': acc_label1,
                    'acc0': acc_label0,
                    'acc_overall': acc_overall
                }
                test_logs.append(test_log)
                
                if args.log:
                    print(f'{e}: Test AUC:{auc_score:.4f}, AP:{ap_score:.4f}, F1:{f1_score_val:.4f}, G-Mean:{g_mean:.4f}, ACC1:{acc_label1:.4f}, ACC0:{acc_label0:.4f}, ACC_overall:{acc_overall:.4f}')

                # 更新最佳测试指标 (以AUC为主要标准)
                if auc_score > best_test_auc:
                    best_test_auc = auc_score
                    best_test_f1_macro = f1_score_val
                    best_test_gmean = g_mean
                    best_test_ap = ap_score
                    best_test_acc1 = acc_label1
                    best_test_acc0 = acc_label0
                    best_test_epoch = e
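                    # Also save the model when this AUC beats early_stop.best_eval
                    # (keeps the early-stopping logic).
                    # NOTE(review): auc_score at this point is the test-set AUC (it was
                    # reassigned by the test evaluation above), not the validation AUC.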
                    if auc_score > early_stop.best_eval:
                        torch.save(model, model_path)


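                # Early-stopping check.
                # NOTE(review): meant to be based on the validation set, but auc_score
                # was overwritten by the test evaluation above, so the test AUC is
                # what gets passed to early_stop.step().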
                do_store, do_stop = early_stop.step(auc_score, e) # 使用验证集AUC进行早停
                if do_store and auc_score <= early_stop.best_eval: # 确保只有在验证集表现提升时才因为早停保存
                    torch.save(model, model_path) # 如果验证集表现好，也保存模型

                # if do_stop:
                #     print(f"Early stopping at epoch {e} based on validation AUC.")
                #     break
                    
        except RuntimeError as e:
            print(f"验证/测试过程中出错: {e}")
            print("继续下一个epoch的训练，跳过当前epoch的验证/测试")
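            # Record a default value so the early-stopping logic does not break.
            # NOTE(review): inside this handler `e` is the caught RuntimeError, not the
            # epoch index, so `e <= 5` below compares an exception with an int.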
            if len(valid_logs) == 0 or e <= 5:  # 如果是前几个epoch出错，不触发早停
                early_stop.best_eval = 0
                early_stop.counter = 0
    
    print('End training')
    
    # 输出最佳测试结果
    print(f"===== 最佳测试结果 (第{best_test_epoch}轮) =====")
    print(f"最佳 AUC: {best_test_auc:.4f}")
    print(f"最佳 F1-macro: {best_test_f1_macro:.4f}")
    print(f"最佳 G-Mean: {best_test_gmean:.4f}")
    print(f"最佳 AP: {best_test_ap:.4f}")
    print(f"最佳 ACC1: {best_test_acc1:.4f}")
    print(f"最佳 ACC0: {best_test_acc0:.4f}")
    print("=====================================")
    
    '''
    # test model
    '''
    print('Test model...')
    model = torch.load(model_path)      
    with torch.no_grad():
        model.eval()
        test_mask = dataset.ndata['test_mask'].bool()
        test_labels = dataset.ndata['label'][test_mask]
        test_labels = test_labels.cpu().numpy()
        logits = model(dataset)[test_mask]
        
        # 修改这里：使用自定义的evaluation函数，而不是utils.py中的evaluate函数
        auc_score, ap_score, f1_macro, g_mean, acc_label0, acc_label1, acc_overall = evaluation(
            logits, 
            test_labels
        )
        
        # 更新结果字典
        results['F1-macro'].append(f1_macro)
        results['AUC'].append(auc_score)
        results['G-Mean'].append(g_mean)
        results['AP'].append(ap_score)
        results['ACC1'].append(acc_label1)
        results['ACC0'].append(acc_label0)
        
        print(f'Test: F1-macro:{f1_macro}, AUC:{auc_score}, G-Mean:{g_mean}, AP:{ap_score}, ACC1:{acc_label1}, ACC0:{acc_label0}')
    
    # 保存所有日志
    log_data = {
        'valid_logs': valid_logs,
        'test_logs': test_logs,  # 添加每个epoch的测试日志
        'test_results': results,
        'best_results': {
            'epoch': best_test_epoch,
            'auc': best_test_auc,
            'f1': best_test_f1_macro,
            'g_mean': best_test_gmean,
            'ap': best_test_ap,
            'acc1': best_test_acc1,
            'acc0': best_test_acc0
        }
    }
    
    with open(log_path, 'w') as f:
        json.dump(log_data, f, indent=4)
    

