import numpy as np
import dgl
import torch
import os
import logging
from sklearn.metrics import average_precision_score, f1_score, roc_auc_score, accuracy_score
import torch.optim as optim
from scipy.io import loadmat
import pandas as pd
import pickle
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder, QuantileTransformer
from dgl.dataloading import MultiLayerFullNeighborSampler
from dgl.dataloading import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from .gtan_model import GraphAttnModel
from . import *
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from datetime import datetime
def test(logits, y_eval):
    """Compute binary-classification evaluation metrics from model outputs.

    Args:
        logits: Tensor of per-class scores; ``torch.exp`` is applied to it and
            column 1 is used as the positive-class probability.
            NOTE(review): this presumably expects log-softmax outputs of shape
            [num_samples, 2] -- confirm against the caller.
        y_eval: Array-like of ground-truth labels (0 or 1), one per row of
            ``logits``.

    Returns:
        tuple: ``(auc_score, ap_score, f1_score_val, g_mean, acc_label0,
        acc_label1, acc_overall)`` where ``acc_label0`` / ``acc_label1`` are
        the fractions of label-0 / label-1 samples predicted correctly
        (0.0 when the class is absent) and ``f1_score_val`` is macro-averaged.
    """
    probs = torch.exp(logits).cpu().detach()
    positive_class_probs = probs[:, 1].numpy()
    y_true = np.array(y_eval)

    # Overall ROC-AUC using the positive-class probability as the score.
    auc_score = roc_auc_score(y_true, positive_class_probs)

    # Hard predictions with a fixed 0.5 threshold.
    label_prob = (positive_class_probs >= 0.5).astype(int)
    acc_overall = accuracy_score(y_true, label_prob)

    # Per-label accuracy; guard against a class that is absent from y_eval.
    is_label0 = y_true == 0
    is_label1 = y_true == 1
    num_label0 = np.sum(is_label0)
    num_label1 = np.sum(is_label1)
    if num_label0 > 0:
        acc_label0 = np.sum(is_label0 & (label_prob == 0)) / num_label0
    else:
        acc_label0 = 0.0
    if num_label1 > 0:
        acc_label1 = np.sum(is_label1 & (label_prob == 1)) / num_label1
    else:
        acc_label1 = 0.0

    ap_score = average_precision_score(y_true, positive_class_probs)
    f1_score_val = f1_score(y_true, label_prob, average='macro')
    g_mean = calculate_g_mean(y_true, label_prob)

    return auc_score, ap_score, f1_score_val, g_mean, acc_label0, acc_label1, acc_overall
# Module-level default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def initialize_centroids(features, k):
    """Pick ``k`` initial cluster centres with a k-means++-style seeding.

    The first centre is a uniformly random row of ``features``; every further
    centre is sampled with probability proportional to each row's distance to
    its nearest already-chosen centre.

    NOTE(review): canonical k-means++ weights by the *squared* distance; the
    linear weighting of the original implementation is kept here.

    Args:
        features: Tensor of shape [num_nodes, feature_dim].
        k: Number of centres to select.

    Returns:
        Tensor of shape [k, feature_dim] on the same device and with the same
        dtype as ``features``.
    """
    num_nodes = features.size(0)
    centroids = torch.zeros(
        k, features.size(1), device=features.device, dtype=features.dtype)

    # First centre: uniformly random node.
    first_id = torch.randint(num_nodes, (1,)).item()
    centroids[0] = features[first_id]

    for i in range(1, k):
        # Distance of every node to its closest already-selected centre.
        distances = torch.min(torch.cdist(features, centroids[:i]), dim=1)[0]
        total = distances.sum()
        if total > 0:
            probabilities = distances / total
        else:
            # Every node coincides with a chosen centre (e.g. duplicated
            # rows): dividing by zero would give NaN probabilities and make
            # torch.multinomial raise, so fall back to uniform sampling.
            probabilities = torch.full_like(distances, 1.0 / num_nodes)
        next_id = torch.multinomial(probabilities, 1).item()
        centroids[i] = features[next_id]

    return centroids

def check_convergence(centroids, prev_centroids, tol=1e-4):
    """Tell whether the centroids moved less than ``tol`` since the last step.

    The movement is measured as the norm of the element-wise difference
    between the two centroid tensors; the result is a boolean tensor.
    """
    shift = torch.norm(centroids - prev_centroids)
    return shift < tol
def robust_node_clustering(features, k=2, temperature=0.1, max_iterations=10, labeled_features=None, labeled_classes=None):
    """Cluster node features into ``k`` soft clusters.

    Centroids are obtained without gradient tracking, either from the
    per-class means of labelled samples (when both ``labeled_features`` and
    ``labeled_classes`` are given) or from k-means++-style seeding followed by
    up to ``max_iterations`` soft k-means updates. The final soft assignments
    are then recomputed outside ``no_grad`` so gradients can flow back into
    ``features``.

    Args:
        features: Node features of the original graph, [num_nodes, feature_dim].
        k: Number of clusters (default 2, i.e. binary classification).
        temperature: Temperature controlling how sharp the soft assignment is.
        max_iterations: Maximum number of centroid refinement iterations
            (only used when no labelled samples are supplied).
        labeled_features: Features of labelled samples, [num_labeled, feature_dim].
        labeled_classes: Labels of those samples, [num_labeled].

    Returns:
        tuple: ``(original_cluster_assignments, view1_cluster_assignments,
        view2_cluster_assignments, centroids)``. All three assignment entries
        are the *same* tensor of shape [num_nodes, k]; ``centroids`` has shape
        [k, feature_dim].
    """
    feature_dim = features.size(1)
    device = features.device
    use_labels = labeled_features is not None and labeled_classes is not None

    # Centroid estimation never needs gradients.
    with torch.no_grad():
        if use_labels:
            # Initialise every centroid from the mean of its labelled samples.
            centroids = torch.zeros(
                k, feature_dim, device=device, dtype=features.dtype)
            for i in range(k):
                class_indices = torch.where(labeled_classes == i)[0]
                if len(class_indices) > 0:
                    centroids[i] = labeled_features[class_indices].mean(dim=0)
                else:
                    # No sample of this class: random unit-norm direction.
                    centroids[i] = F.normalize(
                        torch.randn(feature_dim, device=device), p=2, dim=0)

            # Rescale so all centroids share the same (unit) norm.
            norms = torch.norm(centroids, dim=1, keepdim=True)
            centroids = centroids / (norms + 1e-10)  # avoid division by zero
        else:
            # Unsupervised path: seed from the original features, then refine.
            centroids = initialize_centroids(features, k)
            prev_centroids = centroids.clone()

            for _ in range(max_iterations):
                distances = torch.cdist(features, centroids)  # [num_nodes, k]

                # NOTE(review): the logits are divided by `temperature` here
                # and gumbel_softmax divides by tau=temperature again, so the
                # temperature is effectively applied twice. Kept as-is.
                logits = -distances / temperature
                cluster_assignments = F.gumbel_softmax(logits, tau=temperature, hard=False)

                # Weighted mean of the features per cluster.
                new_centroids = torch.zeros_like(centroids)
                for j in range(k):
                    weights = cluster_assignments[:, j].unsqueeze(1)  # [num_nodes, 1]
                    weight_sum = weights.sum()
                    if weight_sum > 0:
                        new_centroids[j] = (features * weights).sum(0) / weight_sum
                    else:
                        # Empty cluster: keep the previous centre.
                        new_centroids[j] = centroids[j].clone()

                centroids = new_centroids

                if check_convergence(centroids, prev_centroids, tol=1e-4):
                    break

                prev_centroids = centroids.clone()

    # Final assignment, computed with gradient tracking enabled.
    distances_original = torch.cdist(features, centroids)  # [num_nodes, k]
    logits_original = -distances_original / temperature
    original_cluster_assignments = F.gumbel_softmax(logits_original, tau=temperature, hard=False)

    # Both "augmented view" slots currently alias the original assignment.
    view1_cluster_assignments = original_cluster_assignments
    view2_cluster_assignments = original_cluster_assignments

    return original_cluster_assignments, view1_cluster_assignments, view2_cluster_assignments, centroids
def compute_clustering_loss(features, cluster_assignments, centroids, epsilon=1e-6):
    """Clustering loss combining compactness and centroid separation.

    Features and centroids are L2-normalised first. Nodes are hard-assigned
    via ``argmax`` of ``cluster_assignments``; cluster 1 is treated as the
    positive cluster and cluster 0 as the negative one.

    The loss is ``0.5 * intra + 0.5 * inter + 0.1 * joint`` where
      * ``intra`` is the mean distance of positive nodes to centroid 1 plus
        the mean distance of negative nodes to centroid 0,
      * ``inter`` is the negated mean pairwise centroid distance,
      * ``joint`` is the mean squared distance of every node to its assigned
        centroid divided by the mean pairwise centroid distance.

    Args:
        features: Node features, [N, D].
        cluster_assignments: Soft assignments, [N, K].
        centroids: Cluster centres, [K, D].
        epsilon: Small constant guarding the division in ``joint``.

    Returns:
        tuple: ``(total_loss, num_pos, num_neg)`` with the loss tensor and the
        number of nodes hard-assigned to cluster 1 and cluster 0.
    """
    features = F.normalize(features, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)

    # Hard assignments are only used for indexing, so no gradient is needed.
    with torch.no_grad():
        hard_assignments = torch.argmax(cluster_assignments, dim=1)
        pos_indices = torch.nonzero(hard_assignments == 1).squeeze(-1)
        neg_indices = torch.nonzero(hard_assignments == 0).squeeze(-1)
        num_pos = pos_indices.numel()
        num_neg = neg_indices.numel()

    distances = torch.cdist(features, centroids)  # [N, K]
    if num_pos > 0:
        intra_positive_loss = torch.mean(distances[pos_indices, 1])
    else:
        intra_positive_loss = torch.tensor(0.0, device=features.device)
    if num_neg > 0:
        intra_negative_loss = torch.mean(distances[neg_indices, 0])
    else:
        intra_negative_loss = torch.tensor(0.0, device=features.device)
    intra_loss = intra_positive_loss + intra_negative_loss

    # Reward centroids that are far apart.
    mean_centroid_dist = torch.mean(torch.pdist(centroids))
    inter_loss = -mean_centroid_dist

    # Compactness relative to the centroid separation.
    expanded_centroids = torch.index_select(centroids, 0, hard_assignments)
    compactness = torch.mean(torch.sum((features - expanded_centroids) ** 2, dim=1))
    joint_reg = compactness / (mean_centroid_dist + epsilon)

    total_loss = 0.5 * intra_loss + 0.5 * inter_loss + 0.1 * joint_reg
    return total_loss, num_pos, num_neg

def nt_xent_loss(z_i, z_j, temperature=0.01):
    """NT-Xent loss (normalised temperature-scaled cross entropy).

    Row ``n`` of ``z_i`` and row ``n`` of ``z_j`` form a positive pair; every
    other row of the concatenated batch acts as a negative.

    The loss is evaluated in log space with ``logsumexp``. With the default
    temperature of 0.01 a cosine similarity of 1 yields ``exp(100)``, which
    overflows float32 and turns a naive ``exp`` / ``log`` formulation into NaN.
    All helper tensors are created on ``z_i.device`` so the function does not
    depend on any module-level device setting.

    :param z_i: Tensor [N, D], representations of the first augmented view.
    :param z_j: Tensor [N, D], representations of the second augmented view.
    :param temperature: Float, temperature scaling factor for the similarities.
    :return: Scalar tensor, mean loss over the 2N anchors.
    """
    dev = z_i.device

    # Unit-normalise so the dot product below is the cosine similarity.
    z_i = F.normalize(z_i, dim=-1)
    z_j = F.normalize(z_j, dim=-1)
    representations = torch.cat([z_i, z_j], dim=0)
    total = representations.size(0)

    logits = representations @ representations.t() / temperature  # [2N, 2N]

    # Rows sharing the same original index are positives of each other.
    labels = torch.arange(z_i.size(0), device=dev).repeat(2)
    positive_mask = labels[:, None] == labels[None, :]

    # Drop the self-similarity terms on the diagonal.
    off_diagonal = ~torch.eye(total, dtype=torch.bool, device=dev)
    logits = logits[off_diagonal].view(total, -1)
    positive_mask = positive_mask[off_diagonal].view(total, -1)

    # -log(sum_pos exp / sum_all exp), computed stably.
    log_numerator = torch.logsumexp(logits[positive_mask].view(total, -1), dim=-1)
    log_denominator = torch.logsumexp(logits, dim=-1)
    return (log_denominator - log_numerator).mean()

def generate_contrastive_pairs(batch_nodes, labels):
    """Sample one positive and one negative partner per node of a batch.

    For every position ``i`` in the batch a same-label partner (whose node id
    differs from node ``i``) and a different-label partner are drawn with
    ``np.random.choice``. A pair is only emitted when at least one candidate
    exists. Pairs hold *batch-local* positions, not global node ids.

    :param batch_nodes: Node indices of the current batch (tensor on any
        device, NumPy array or sequence).
    :param labels: Labels indexed by global node id (tensor on any device,
        NumPy array or sequence).
    :return: Tuple ``(positive_pairs, negative_pairs)`` of lists of
        ``(anchor_position, partner_position)`` tuples.
    """
    def _to_numpy(values):
        # Handle CPU and CUDA tensors uniformly; pass other inputs to NumPy.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    positive_pairs = []
    negative_pairs = []

    nodes = _to_numpy(batch_nodes)
    if nodes.size == 0:
        return positive_pairs, negative_pairs

    # Labels of the nodes in this batch, aligned with batch positions.
    batch_labels = _to_numpy(labels)[nodes]

    for i, (node, node_label) in enumerate(zip(nodes, batch_labels)):
        same_label = batch_labels == node_label

        # Same class, but not the anchor node itself.
        same_class_indices = np.flatnonzero(same_label & (nodes != node))
        if same_class_indices.size:
            positive_pairs.append((i, np.random.choice(same_class_indices)))

        # Any node of a different class.
        diff_class_indices = np.flatnonzero(~same_label)
        if diff_class_indices.size:
            negative_pairs.append((i, np.random.choice(diff_class_indices)))

    return positive_pairs, negative_pairs

def get_augmented_view(edge_indexs, feat_data, aug_type, drop_rate=0.2):
    """Build an augmented view of a multi-relation graph.

    Each entry of ``edge_indexs`` is indexed as ``[main_edge_index,
    tree_edge_indexes]``: a [2, E] edge tensor plus an iterable of further
    [2, E'] edge tensors.

    Args:
        edge_indexs: Per-relation edge structure of the original graph.
        feat_data: Node feature tensor, [num_nodes, feature_dim].
        aug_type: One of ``'edge_drop'``, ``'feat_drop'``, ``'degree'``,
            ``'pr'``, ``'weighted_feat'``.
        drop_rate: Fraction of features / edges to drop.

    Returns:
        For feature augmentations (``'feat_drop'``, ``'weighted_feat'``):
        ``(edge_indexs, augmented_features)``.
        For edge augmentations (``'edge_drop'``, ``'degree'``, ``'pr'``):
        ``(feat_data, augmented_edge_indexs)``.
        NOTE(review): the tuple order differs between the two families; it is
        kept unchanged because callers may rely on it.

    Raises:
        ValueError: If ``aug_type`` is not supported. The check runs before
            any edge is processed, so it also fires for an empty
            ``edge_indexs``.
    """
    if aug_type == 'feat_drop':
        # Zero out a random subset of feature columns.
        feat_mask = torch.rand(feat_data.size(1)) > drop_rate
        feat_aug = feat_data.clone()
        feat_aug[:, ~feat_mask] = 0
        return edge_indexs, feat_aug

    if aug_type == 'weighted_feat':
        # Weighted feature dropping; node degree is taken from the main graph
        # of the first relation.
        node_deg = degree(edge_indexs[0][0][1])
        feat_weights = feature_drop_weights(feat_data, node_deg)
        feat_aug = drop_feature_weighted(feat_data, feat_weights, drop_rate)
        return edge_indexs, feat_aug

    # Edge augmentations: choose the per-edge-tensor dropping strategy once.
    if aug_type == 'edge_drop':
        def drop_edges(edges):
            # Uniformly random edge dropping.
            keep = torch.rand(edges.size(1)) > drop_rate
            return edges[:, keep]
    elif aug_type == 'degree':
        def drop_edges(edges):
            # Degree-weighted edge dropping.
            weights = degree_drop_weights(edges)
            return drop_edge_weighted(edges, weights, p=drop_rate)
    elif aug_type == 'pr':
        def drop_edges(edges):
            # PageRank-weighted edge dropping.
            weights = pr_drop_weights(edges, aggr='sink', k=10)
            return drop_edge_weighted(edges, weights, p=drop_rate)
    else:
        raise ValueError(f"不支持的增强类型: {aug_type}")

    augmented_edge_indexs = []
    for relation in edge_indexs:
        # Main graph first, then its trees, matching the original draw order.
        edge_index_main = drop_edges(relation[0])
        edge_index_trees = [drop_edges(tree_edge) for tree_edge in relation[1]]
        augmented_edge_indexs.append([edge_index_main, edge_index_trees])

    return feat_data, augmented_edge_indexs



# Switch read by get_current_mu: when True, mu is scaled by sigmoid_rampup.
mu_rampup = True

# NOTE(review): this module-level value is not read anywhere in the visible code.
consistency_rampup = None
def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up schedule from https://arxiv.org/abs/1610.02242.

    Returns ``exp(-5 * (1 - t) ** 2)`` with ``t = current / rampup_length``
    after clipping ``current`` to ``[0, rampup_length]``; a ramp-up length of
    zero disables the schedule and always yields 1.0.
    """
    if rampup_length == 0:
        return 1.0

    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))
def get_current_mu(epoch):
    """Return the weight ``mu`` to use at ``epoch``.

    The base value is 1.5. When the module-level ``mu_rampup`` flag is set,
    it is scaled by ``sigmoid_rampup`` over a ramp-up length of 500
    (consistency ramp-up from https://arxiv.org/abs/1610.02242).
    """
    base_mu = 1.5
    if not mu_rampup:
        return base_mu

    ramp_length = 500
    return base_mu * sigmoid_rampup(epoch, ramp_length)

# 计算几何平均值Gmean
def geometric_mean(recall_0, recall_1):
    """Return the geometric mean (square root of the product) of two recalls."""
    product = recall_0 * recall_1
    return np.sqrt(product)

# 计算G-mean
def calculate_g_mean(y_true, y_pred):
    """Compute the G-mean of the per-class recalls for labels 0 and 1.

    ``y_true`` and ``y_pred`` are indexed with boolean masks, so both are
    expected to be NumPy arrays. A class that does not occur in ``y_true``
    contributes a recall of 0.
    """
    is_positive = (y_true == 1)
    is_negative = (y_true == 0)

    if np.any(is_positive):
        recall_pos = np.mean(y_pred[is_positive] == y_true[is_positive])
    else:
        recall_pos = 0

    if np.any(is_negative):
        recall_neg = np.mean(y_pred[is_negative] == y_true[is_negative])
    else:
        recall_neg = 0

    return geometric_mean(recall_neg, recall_pos)

# 添加 GradientAwareFocalLoss 类
class GradientAwareFocalLoss(nn.Module):
    """Focal loss re-weighted by class hardness and per-sample gradient size.

    Every sample's cross-entropy term is multiplied by
      * a focal factor ``(1 - p_t) ** gamma_focal``,
      * a class weight derived from a running count of how often each class
        appears among the ``k_percent`` hardest samples, and
      * the norm of the gradient of an auxiliary loss w.r.t. the inputs,
        raised to ``gamma_grad``.
    The combined weights are normalised to mean 1 before averaging.

    Args:
        num_classes: Number of classes (size of the running buffers).
        k_percent: Percentage of samples treated as "hard" when updating the
            running class counts.
        gamma_focal: Exponent of the focal factor.
        gamma_ga: Class weights are ``(1 / count) ** (1 - gamma_ga)``.
        gamma_grad: Exponent applied to the gradient magnitude.
        use_softmax: If True ``inputs`` are treated as logits and softmaxed
            over dim 1; otherwise they are used as probabilities directly.
    """

    def __init__(self, num_classes, k_percent=10, gamma_focal=2.0, gamma_ga=0.5, gamma_grad=1.0, use_softmax=True):
        super(GradientAwareFocalLoss, self).__init__()
        self.num_classes = num_classes
        self.k_percent = k_percent
        self.gamma_focal = gamma_focal
        self.gamma_ga = gamma_ga
        self.gamma_grad = gamma_grad
        self.use_softmax = use_softmax
        # Running statistics; buffers follow the module across devices.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_weights', torch.ones(num_classes))

    def forward(self, inputs, targets):
        """Compute the loss.

        Args:
            inputs: Tensor [B, C] or [B, C, ...] of logits (or probabilities
                when ``use_softmax`` is False).
            targets: Integer class indices with one entry per sample /
                spatial location; flattened internally.

        Returns:
            Scalar loss tensor.

        Side effects:
            Updates the ``class_counts`` and ``class_weights`` buffers on
            every call.
        """
        B, C = inputs.shape[:2]
        N = inputs.shape[2:].numel() * B
        # Move the class dimension last so rows line up with flat targets.
        channels_last = (0, *range(2, inputs.dim()), 1)

        probs = F.softmax(inputs, dim=1) if self.use_softmax else inputs
        probs = probs.permute(*channels_last).contiguous().view(-1, C)
        targets = targets.view(-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -torch.log(pt + 1e-8)

        # Per-sample gradient magnitude, computed on a detached copy.
        # enable_grad() keeps this working when the loss is evaluated inside
        # torch.no_grad() (e.g. during validation).
        with torch.enable_grad():
            inputs_grad = inputs.detach().requires_grad_(True)
            probs_grad = F.softmax(inputs_grad, dim=1) if self.use_softmax else inputs_grad
            # Same layout as `probs`, so each row matches its target even for
            # inputs with more than two dimensions.
            probs_grad = probs_grad.permute(*channels_last).contiguous().view(-1, C)
            # NOTE(review): F.cross_entropy applies log_softmax to its input,
            # so probabilities are softmaxed a second time here. Kept as in
            # the original because it only shapes the weighting term.
            loss_grad = F.cross_entropy(probs_grad, targets, reduction='none')
            gradients = torch.autograd.grad(
                outputs=loss_grad,
                inputs=inputs_grad,
                grad_outputs=torch.ones_like(loss_grad),
            )[0]

        gradients = gradients.permute(*channels_last).contiguous().view(-1, C)
        grad_magnitude = gradients.norm(p=2, dim=1)
        grad_weight = (grad_magnitude + 1e-8) ** self.gamma_grad

        # Exponential moving average of class counts among the hardest samples.
        num_topk = max(1, int(self.k_percent / 100 * N))
        _, topk_indices = torch.topk(ce_loss, num_topk, sorted=False)
        topk_targets = targets[topk_indices]
        current_counts = torch.bincount(topk_targets, minlength=self.num_classes).float()
        self.class_counts = 0.9 * self.class_counts + 0.1 * current_counts
        effective_counts = self.class_counts + 1e-8
        self.class_weights = (1.0 / effective_counts) ** (1.0 - self.gamma_ga)
        self.class_weights = self.class_weights / self.class_weights.sum() * C

        focal_weight = (1 - pt) ** self.gamma_focal
        class_weight = self.class_weights[targets]

        difficulty_weight = class_weight * grad_weight
        difficulty_weight = difficulty_weight / self._safe_mean(difficulty_weight)

        final_weight = focal_weight * difficulty_weight
        final_weight = final_weight / self._safe_mean(final_weight)

        loss = (final_weight * ce_loss).mean()
        return loss

    @staticmethod
    def _safe_mean(weights):
        """Mean of ``weights``, replaced by 1 when it is exactly zero.

        A zero mean happens when every prediction is fully saturated
        (``p_t == 1``); dividing by it would turn the whole loss into NaN.
        """
        mean = weights.mean()
        return torch.where(mean > 0, mean, torch.ones_like(mean))

# 添加 LPLLoss_advanced 类
class LPLLoss_advanced(nn.Module):
    """Cross-entropy on logits perturbed by a class-adaptive iterative scheme.

    For each batch, a per-sample number of perturbation steps and a per-sample
    step size are derived from running class statistics and from the
    per-sample gradient norm of the cross-entropy. The logits are then
    shifted step by step (see ``compute_adv_sign``) and the loss is the
    cross-entropy of the shifted logits.

    NOTE(review): ``compute_adaptive_params`` uses ``1 - minority_idx`` as the
    majority class, so the class effectively assumes ``num_classes == 2``.
    """

    def __init__(self, num_classes=2, pgd_nums=50, alpha=0.1, min_class_factor=3.0):
        """Store hyper-parameters and create the running-statistics buffers.

        Args:
            num_classes: Number of classes.
            pgd_nums: Base number of perturbation steps.
            alpha: Base perturbation step size.
            min_class_factor: Minimum ratio enforced between the minority and
                the majority class step size.
        """
        super().__init__()
        self.num_classes = num_classes
        self.pgd_nums = pgd_nums
        self.alpha = alpha
        self.min_class_factor = min_class_factor
        self.criterion = nn.CrossEntropyLoss()

        # Exponential moving averages, updated in update_statistics.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_grad_mags', torch.zeros(num_classes))
        self.momentum = 0.9

    def update_statistics(self, logit, y):
        """Update the running per-class sample counts and loss magnitudes.

        NOTE(review): despite its name, ``class_grad_mags`` tracks the mean
        cross-entropy *loss* per class, not a gradient norm.
        """
        with torch.no_grad():
            batch_counts = torch.bincount(y, minlength=self.num_classes).float()
            self.class_counts = self.momentum * self.class_counts + (1 - self.momentum) * batch_counts

            grad_mags = torch.zeros(self.num_classes, device=logit.device)
            for c in range(self.num_classes):
                class_mask = (y == c)
                if torch.sum(class_mask) > 0:
                    class_logits = logit[class_mask]
                    class_targets = y[class_mask]
                    ce_loss = F.cross_entropy(class_logits, class_targets, reduction='none')
                    grad_mags[c] = ce_loss.mean().item()

            self.class_grad_mags = self.momentum * self.class_grad_mags + (1 - self.momentum) * grad_mags

    def compute_adaptive_params(self, logit, y):
        """Derive per-sample step counts and step sizes.

        Calls ``update_statistics`` first, so the running buffers change on
        every invocation.

        Returns:
            tuple: ``(sample_steps, sample_alphas)`` -- a long tensor and a
            float tensor, both shaped like ``y``.
        """
        with torch.no_grad():
            self.update_statistics(logit, y)

            total_samples = torch.sum(self.class_counts)
            class_ratios = self.class_counts / (total_samples + 1e-8)

            minority_idx = torch.argmin(class_ratios).item()
            # Binary-only: the "other" class is taken to be the majority.
            majority_idx = 1 - minority_idx

            imbalance_ratio = class_ratios[majority_idx] / (class_ratios[minority_idx] + 1e-8)
            imbalance_ratio_tensor = torch.tensor([imbalance_ratio], device=logit.device)
            imbalance_factor = torch.clamp(imbalance_ratio_tensor, 1.0, 10.0)

            grad_scale = F.softmax(self.class_grad_mags, dim=0)

            class_steps = torch.zeros(self.num_classes, device=logit.device, dtype=torch.long)
            class_alphas = torch.zeros(self.num_classes, device=logit.device, dtype=torch.float)

            # Step counts range between 0.5x and 2x of pgd_nums.
            max_steps = int(self.pgd_nums * 2.0)
            min_steps = max(1, int(self.pgd_nums * 0.5))

            for c in range(self.num_classes):
                # Rarer classes get a larger frequency factor, hence more steps.
                freq_factor = torch.sqrt(1.0 / (class_ratios[c] + 1e-8))
                steps = min_steps + int((max_steps - min_steps) * freq_factor / (freq_factor + 1.0))
                class_steps[c] = steps

                alpha_base = self.alpha * (1.0 + grad_scale[c].item() * 2.0)

                if c == minority_idx:
                    alpha = alpha_base * min(5.0, imbalance_factor.item() ** 0.5)
                else:
                    alpha = alpha_base

                class_alphas[c] = alpha

            # Enforce at least 1.5x the majority's steps for the minority class.
            if class_steps[minority_idx] < class_steps[majority_idx] * 1.5:
                class_steps[minority_idx] = int(class_steps[majority_idx] * 1.5)

            # Enforce at least min_class_factor times the majority's step size.
            if class_alphas[minority_idx] < class_alphas[majority_idx] * self.min_class_factor:
                class_alphas[minority_idx] = class_alphas[majority_idx] * self.min_class_factor

            sample_steps = torch.zeros_like(y, dtype=torch.long)
            sample_alphas = torch.zeros_like(y, dtype=torch.float)

            # Broadcast the per-class settings to the individual samples.
            for c in range(self.num_classes):
                class_mask = (y == c)
                sample_steps[class_mask] = class_steps[c]
                sample_alphas[class_mask] = class_alphas[c]

            # Gradients are needed here although the surrounding context is
            # no_grad, hence the explicit enable_grad on a detached copy.
            with torch.enable_grad():
                logit_grad = logit.detach().clone().requires_grad_(True)
                loss = F.cross_entropy(logit_grad, y, reduction='none')

                grads = torch.autograd.grad(
                    outputs=loss.sum(),
                    inputs=logit_grad,
                    create_graph=False,
                    retain_graph=False
                )[0]

                sample_grad_norms = torch.norm(grads, p=2, dim=1)
                # Softmax over the batch dimension: relative difficulty.
                sample_difficulties = F.softmax(sample_grad_norms, dim=0)

                # Step-size scale in [0.8, 1.5] relative to the hardest sample.
                difficulty_scales = 0.8 + 0.7 * sample_difficulties / (torch.max(sample_difficulties) + 1e-8)

                sample_alphas = sample_alphas * difficulty_scales

                # Step-count scale in [1.0, 1.5]; truncated back to integers.
                steps_difficulty_scales = 1.0 + 0.5 * sample_difficulties / (torch.max(sample_difficulties) + 1e-8)
                sample_steps = (sample_steps.float() * steps_difficulty_scales).long()

            return sample_steps, sample_alphas

    def compute_adv_sign(self, logit, y, sample_alphas):
        """Compute one perturbation step for every sample.

        Returns:
            tuple: ``(adv_logit, sub)`` -- the per-sample logit increment and
            the per-class difference between the mean class confidence
            threshold and each class's mean confidence.
        """
        with torch.no_grad():
            logit_softmax = F.softmax(logit, dim=-1)
            y_onehot = F.one_hot(y, num_classes=self.num_classes)

            # Row c: summed softmax output over the samples of class c.
            sum_class_logit = torch.matmul(y_onehot.permute(1, 0)*1.0, logit_softmax)
            sum_class_num = torch.sum(y_onehot, dim=0)

            # Avoid division by zero for classes absent from the batch.
            sum_class_num = torch.where(sum_class_num == 0, 100, sum_class_num)
            mean_class_logit = torch.div(sum_class_logit, sum_class_num.reshape(-1, 1))

            grad = mean_class_logit - torch.eye(self.num_classes, device=logit.device)
            # NOTE(review): the norm is taken over dim=0 (per column) but is
            # reshaped to divide per row -- confirm this is intended.
            grad = torch.div(grad, torch.norm(grad, p=2, dim=0).reshape(-1, 1) + 1e-8)

            mean_class_p = torch.diag(mean_class_logit)
            # NOTE(review): zero counts were replaced by 100 above, so this
            # mask is always all-True and absent classes enter the threshold.
            mean_mask = sum_class_num > 0
            mean_class_thr = torch.mean(mean_class_p[mean_mask])
            sub = mean_class_thr - mean_class_p
            sign = sub.sign()

            alphas_expanded = sample_alphas.unsqueeze(1).expand(-1, self.num_classes)
            adv_logit = torch.index_select(grad, 0, y) * alphas_expanded * sign[y].unsqueeze(1)

            return adv_logit, sub

    def compute_eta(self, logit, y):
        """Compute the total logit perturbation ``eta`` for each sample.

        All intermediate logits up to the largest per-sample step count are
        stored, then each sample picks the state at its own step count.

        Returns:
            tuple: ``(eta, sample_steps, sample_alphas)``.
        """
        with torch.no_grad():
            sample_steps, sample_alphas = self.compute_adaptive_params(logit, y)

            logit_clone = logit.clone()

            max_steps = torch.max(sample_steps).item()

            # Index 0 holds the unperturbed logits, index i the state after i steps.
            logit_steps = torch.zeros(
                [max_steps + 1, logit.shape[0], self.num_classes], device=logit.device)

            current_logit = logit.clone()
            logit_steps[0] = current_logit

            for i in range(1, max_steps + 1):
                adv_logit, _ = self.compute_adv_sign(current_logit, y, sample_alphas)
                current_logit = current_logit + adv_logit
                logit_steps[i] = current_logit

            # NOTE(review): per-sample Python loop with .item(); a candidate
            # for vectorised indexing if this becomes a bottleneck.
            logit_news = torch.zeros_like(logit)
            for i in range(logit.shape[0]):
                step = sample_steps[i].item()
                logit_news[i] = logit_steps[step, i]

            eta = logit_news - logit_clone

            return eta, sample_steps, sample_alphas

    def forward(self, models_or_logits, x=None, y=None, is_logits=False):
        """Compute the loss on the perturbed logits.

        Args:
            models_or_logits: Either a callable model (invoked as
                ``models_or_logits(x)``) or the logits themselves.
            x: Model input; only used when ``is_logits`` is False.
            y: Target class indices.
            is_logits: Set to True when ``models_or_logits`` already holds logits.

        Returns:
            tuple: ``(loss_adv, logit, logit_news, sample_steps, sample_alphas)``.
        """
        if is_logits:
            logit = models_or_logits
        else:
            logit = models_or_logits(x)

        # eta is computed under no_grad, so it acts as a constant offset and
        # gradients flow only through `logit`.
        eta, sample_steps, sample_alphas = self.compute_eta(logit, y)

        logit_news = logit + eta

        loss_adv = self.criterion(logit_news, y)

        return loss_adv, logit, logit_news, sample_steps, sample_alphas

def _configure_run_logging(args):
    """Route root-logger output to a per-run log file and to the console."""
    log_dir = os.path.join(os.path.dirname(__file__), "..", "..", "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"gtan_log_{args.get('dataset', 'unknown')}_seed{args['seed']}.txt")
    logging.basicConfig(filename=log_file, level=logging.INFO,
                        format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    root_logger = logging.getLogger('')
    # FileHandler subclasses StreamHandler, so exclude it explicitly. Without
    # this guard every call would stack another console handler and each
    # message would be printed several times.
    has_console = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in root_logger.handlers)
    if not has_console:
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        root_logger.addHandler(console)


def _build_semi_supervised_index(fold_nodes, y, batch_size):
    """Keep one positive and one negative label and lay out the training order.

    Nodes whose label is neither 0 nor 1 are dropped. All remaining labelled
    nodes except the first positive and first negative are treated as
    unlabelled. The returned order repeats the kept labelled nodes at the head
    of every ``batch_size``-sized slice so that an unshuffled loader sees them
    in every batch.

    Returns:
        ``(selected_pos, selected_neg, unlabeled_samples, ordered_nodes)``,
        all plain lists of node ids.

    Raises:
        ValueError: if ``batch_size`` leaves no room for unlabelled nodes.
    """
    fold_labels = y.iloc[fold_nodes].to_numpy()
    pos_samples = fold_nodes[fold_labels == 1].tolist()
    neg_samples = fold_nodes[fold_labels == 0].tolist()

    if not pos_samples:
        logging.warning("训练集中没有正样本，无法选择一个正样本")
    if not neg_samples:
        logging.warning("训练集中没有负样本，无法选择一个负样本")

    selected_pos = pos_samples[:1]
    selected_neg = neg_samples[:1]
    labeled_samples = selected_pos + selected_neg
    unlabeled_samples = pos_samples[1:] + neg_samples[1:]

    unlabeled_per_batch = batch_size - len(labeled_samples)
    if unlabeled_per_batch <= 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be larger than the number of "
            f"labeled samples kept per batch ({len(labeled_samples)})")

    ordered_nodes = []
    for start in range(0, len(unlabeled_samples), unlabeled_per_batch):
        ordered_nodes.extend(labeled_samples)
        ordered_nodes.extend(unlabeled_samples[start:start + unlabeled_per_batch])
    return selected_pos, selected_neg, unlabeled_samples, ordered_nodes


def _derive_pseudo_labels(model_output, cluster_assignments, use_model_output, use_clustering):
    """Turn model output and/or soft cluster assignments into hard pseudo-labels.

    The cluster with fewer members (by arg-max assignment) is mapped to the
    positive class (1) and the other to the negative class (0). When both
    sources are enabled, the softmax of ``model_output`` is averaged with the
    aligned cluster assignments before taking the arg-max.

    Returns:
        A long tensor of pseudo-labels, empty when both sources are disabled.
    """
    if not (use_model_output or use_clustering):
        return torch.tensor([], dtype=torch.long, device=model_output.device)
    if not use_clustering:
        return torch.argmax(model_output, dim=1)

    hard_assignments = torch.argmax(cluster_assignments, dim=1)
    count_c0 = torch.sum(hard_assignments == 0).item()
    count_c1 = torch.sum(hard_assignments == 1).item()
    cluster0_is_minority = count_c0 < count_c1

    if not use_model_output:
        # Clustering only: flip the hard labels when cluster 0 is the minority.
        return 1 - hard_assignments if cluster0_is_minority else hard_assignments

    if cluster0_is_minority:
        # Swap the columns so the layout is [P(negative), P(positive)].
        aligned_cluster_probs = cluster_assignments[:, [1, 0]]
    else:
        aligned_cluster_probs = cluster_assignments
    model_probs = F.softmax(model_output, dim=1)
    return torch.argmax((model_probs + aligned_cluster_probs) / 2.0, dim=1)


def gtan_main(feat_df, graph, train_idx, test_idx, labels, args, cat_features):
    """Train and evaluate the GTAN model with k-fold semi-supervised training.

    In every fold only one positive and one negative training node keep their
    label; the remaining labelled training nodes are treated as unlabelled and
    contribute through consistency, clustering and pseudo-label losses. After
    every epoch the model is evaluated on the validation fold (which drives
    early stopping) and on the test nodes; the test metrics of the epoch with
    the highest test AUC seen so far (across folds) are logged as "best".

    Args:
        feat_df: pandas DataFrame of node features, one row per graph node.
        graph: DGL graph; it is moved to ``args['device']``.
        train_idx: Positional indices of the training nodes.
        test_idx: Positional indices of the test nodes.
        labels: pandas Series of node labels; label 2 is treated as
            "unlabelled" and excluded from the metrics.
        args: Configuration dict. Keys read here: ``device``, ``n_fold``,
            ``batch_size``, ``n_layers``, ``hid_dim``, ``dropout``, ``gated``,
            ``lr``, ``wd``, ``early_stopping``, ``max_epochs`` and optionally
            ``dataset``. ``args['seed']`` is overwritten with 64.
        cat_features: Names of the categorical columns in ``feat_df``.

    Returns:
        ``(feat_df, labels, train_idx, test_idx, graph, cat_features)`` where
        ``labels`` is the long tensor built from the input Series and
        ``graph`` is the graph on the target device.

    Raises:
        ValueError: if ``args['batch_size']`` is not larger than the number
            of labelled nodes kept per batch.
    """
    # NOTE(review): this overwrites any seed supplied by the caller. Kept for
    # reproducibility of existing runs (the old comment said 72, the code
    # has always used 64).
    args['seed'] = 64
    np.random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed_all(args['seed'])

    _configure_run_logging(args)

    device = args['device']
    logging.info(f'Device: {device}')
    graph = graph.to(device)
    num_nodes = len(feat_df)
    oof_predictions = torch.zeros(num_nodes, 2, dtype=torch.float32, device=device)
    # Filled only when the early stopper exposes a truthy ``is_best``; it is
    # not read again inside this function.
    test_predictions = torch.zeros(num_nodes, 2, dtype=torch.float32, device=device)
    kfold = StratifiedKFold(
        n_splits=args['n_fold'], shuffle=True, random_state=args['seed'])

    y_target = labels.iloc[train_idx].values
    num_feat = torch.from_numpy(feat_df.values).float().to(device)
    cat_feat = {col: torch.from_numpy(feat_df[col].values).long().to(
        device) for col in cat_features}

    y = labels
    labels = torch.from_numpy(y.values).long().to(device)

    # Losses applied to the pseudo-labelled (unlabelled) nodes.
    gradient_aware_focal = GradientAwareFocalLoss(num_classes=2,
                                              k_percent=10,
                                              gamma_focal=2,
                                              gamma_ga=0.5,
                                              gamma_grad=1,
                                              use_softmax=True).to(device)

    adaptive_lpl_loss = LPLLoss_advanced(
        num_classes=2,
        pgd_nums=30,
        alpha=0.05,
        min_class_factor=3.5
    ).to(device)

    # Test metrics of the best epoch so far (selected by test AUC, shared
    # across folds).
    best_test_metrics = {
        'auc': 0,
        'f1': 0,
        'ap': 0,
        'acc1': 0,
        'acc0': 0,
        'gmean': 0,
        'epoch': 0
    }

    # For the first `fixed_cluster_epochs` epochs the clustering is seeded
    # with the labelled nodes of the batch; afterwards it runs unconstrained.
    fixed_cluster_epochs = 10
    # Sources used to build pseudo-labels for the unlabelled nodes.
    use_clustering_pseudo_labels = True
    use_original_pseudo_labels = True

    for fold, (trn_idx, val_idx) in enumerate(kfold.split(feat_df.iloc[train_idx], y_target)):
        logging.info(f'Training fold {fold + 1}')

        original_trn_ind = np.array(train_idx)[trn_idx]
        selected_pos, selected_neg, unlabeled_samples, final_trn_ind_list = _build_semi_supervised_index(
            original_trn_ind, y, args['batch_size'])

        # Mark the hidden-label nodes with a mask instead of overwriting the
        # label tensor.
        unlabeled_mask = torch.zeros_like(labels)
        unlabeled_mask[torch.tensor(unlabeled_samples, dtype=torch.long, device=device)] = 1

        logging.info(f'训练集正样本数: {len(selected_pos)}, 负样本数: {len(selected_neg)}, 无标签样本数: {len(unlabeled_samples)}')

        trn_ind = torch.tensor(final_trn_ind_list).long().to(device)
        val_ind = torch.from_numpy(np.array(train_idx)[val_idx]).long().to(device)

        logging.info(f'训练/验证/测试样本数: {len(trn_ind)}, {len(val_ind)}, {len(test_idx)}')

        train_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        train_dataloader = DataLoader(graph,
                                  trn_ind,
                                  train_sampler,
                                  device=device,
                                  use_ddp=False,
                                  batch_size=args['batch_size'],
                                  shuffle=False,  # order is pre-arranged so each batch holds the labelled nodes
                                  drop_last=False,
                                  num_workers=0
                                  )
        val_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        val_dataloader = DataLoader(graph,
                                    val_ind,
                                    val_sampler,
                                    use_ddp=False,
                                    device=device,
                                    batch_size=args['batch_size'],
                                    shuffle=True,
                                    drop_last=False,
                                    num_workers=0,
                                    )
        test_ind = torch.from_numpy(np.array(test_idx)).long().to(device)
        test_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        test_dataloader = DataLoader(graph,
                                     test_ind,
                                     test_sampler,
                                     use_ddp=False,
                                     device=device,
                                     batch_size=args['batch_size'],
                                     shuffle=True,
                                     drop_last=False,
                                     num_workers=0,
                                     )
        # TODO
        model = GraphAttnModel(in_feats=feat_df.shape[1],
                               # presumably hid_dim is split over the 4 attention heads -- TODO confirm
                               hidden_dim=args['hid_dim']//4,
                               n_classes=2,
                               heads=[4]*args['n_layers'],  # [4,4,4]
                               activation=nn.PReLU(),
                               n_layers=args['n_layers'],
                               drop=args['dropout'],
                               device=device,
                               gated=args['gated'],
                               ref_df=feat_df,
                               cat_features=cat_feat).to(device)
        lr = args['lr'] * np.sqrt(args['batch_size']/1024)  # 0.00075
        optimizer = optim.Adam(model.parameters(), lr=lr,
                               weight_decay=args['wd'])
        # The scheduler is stepped once per batch, so milestones count batches.
        lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[
                                   4000, 12000], gamma=0.3)

        earlystoper = early_stopper(
            patience=args['early_stopping'], verbose=False)
        start_epoch = 0

        for epoch in range(start_epoch, args['max_epochs']):
            epoch_pseudo_pos_count = 0
            epoch_pseudo_neg_count = 0
            train_loss_list = []
            current_mu = get_current_mu(epoch)
            model.train()
            for input_nodes, seeds, blocks in train_dataloader:
                # NOTE(review): `labels` still holds the true labels of the
                # "unlabelled" nodes; if load_lpa_subtensor exposes non-seed
                # labels to the model this leaks them -- confirm.
                batch_inputs, batch_work_inputs, _, lpa_labels = load_lpa_subtensor(
                    num_feat, cat_feat, labels, seeds, input_nodes, device)
                blocks = [block.to(device) for block in blocks]

                # Two feature-dropout views of the same batch.
                _, feat_aug1 = get_augmented_view(None, batch_inputs, aug_type='feat_drop', drop_rate=0.1)
                _, feat_aug2 = get_augmented_view(None, batch_inputs, aug_type='feat_drop', drop_rate=0.1)

                out_orig, h_orig = model(blocks, batch_inputs, lpa_labels, batch_work_inputs)
                out_aug1, h_aug1 = model(blocks, feat_aug1, lpa_labels, batch_work_inputs)
                out_aug2, h_aug2 = model(blocks, feat_aug2, lpa_labels, batch_work_inputs)

                unlabeled_in_batch = unlabeled_mask[seeds] == 1
                labeled_in_batch = ~unlabeled_in_batch
                train_batch_logits = out_orig[labeled_in_batch]
                batch_labels = labels[seeds][labeled_in_batch]

                # Supervised loss on the labelled nodes, averaged over views.
                classification_loss1 = F.nll_loss(out_aug1[labeled_in_batch], batch_labels)
                classification_loss2 = F.nll_loss(out_aug2[labeled_in_batch], batch_labels)
                classification_loss = (classification_loss1 + classification_loss2) / 2

                # Contrastive loss over positive pairs of labelled nodes.
                batch_labeled = seeds[labeled_in_batch]
                positive_pairs, _ = generate_contrastive_pairs(batch_labeled, labels)

                if len(positive_pairs) > 0:
                    anchor_idx = torch.tensor([p[0] for p in positive_pairs], dtype=torch.long, device=device)
                    partner_idx = torch.tensor([p[1] for p in positive_pairs], dtype=torch.long, device=device)
                    contrastive_loss_1 = nt_xent_loss(h_aug1[anchor_idx], h_aug1[partner_idx])
                    contrastive_loss_2 = nt_xent_loss(h_aug2[anchor_idx], h_aug2[partner_idx])
                    contrastive_loss = (contrastive_loss_1 + contrastive_loss_2) / 2
                else:
                    contrastive_loss = torch.tensor(0.0).to(device)

                # Consistency between the hidden representations of the views.
                consistency_loss = F.mse_loss(h_aug1, h_aug2)

                h_orig_unlabeled = h_orig[unlabeled_in_batch]
                h1_unlabeled = h_aug1[unlabeled_in_batch]
                h2_unlabeled = h_aug2[unlabeled_in_batch]

                # Clustering always runs because the clustering loss needs it;
                # the flags above only choose the pseudo-label source.
                clustering_kwargs = dict(k=2, temperature=0.8, max_iterations=10)
                if epoch < fixed_cluster_epochs:
                    # NOTE(review): `labeled_classes` receives node ids
                    # (seeds[labeled]), not class labels; `batch_labels` may
                    # have been intended -- confirm against
                    # robust_node_clustering before changing.
                    clustering_kwargs.update(
                        labeled_features=h_orig[labeled_in_batch],
                        labeled_classes=batch_labeled)
                (cluster_assignments_orig,
                 cluster_assignments_view1,
                 cluster_assignments_view2,
                 centroids_orig) = robust_node_clustering(h_orig_unlabeled, **clustering_kwargs)

                # One clustering loss over all three views, using the
                # centroids of the original view.
                all_features = torch.cat([h_orig_unlabeled, h1_unlabeled, h2_unlabeled], dim=0)
                all_assignments = torch.cat(
                    [cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2], dim=0)
                clustering_loss, _, _ = compute_clustering_loss(
                    all_features,
                    all_assignments,
                    centroids_orig
                )

                with torch.no_grad():
                    pseudo_labels = _derive_pseudo_labels(
                        out_orig[unlabeled_in_batch],
                        cluster_assignments_orig,
                        use_model_output=use_original_pseudo_labels,
                        use_clustering=use_clustering_pseudo_labels)
                    epoch_pseudo_pos_count += torch.sum(pseudo_labels == 1).item()
                    epoch_pseudo_neg_count += torch.sum(pseudo_labels == 0).item()

                # Every unlabelled node gets a pseudo-label; the slice only
                # matters when both sources are disabled (no labels at all).
                num_pseudo = pseudo_labels.numel()
                pseudo_logits_1 = out_aug1[unlabeled_in_batch][:num_pseudo]
                pseudo_logits_2 = out_aug2[unlabeled_in_batch][:num_pseudo]

                pseudo_label_loss_1 = gradient_aware_focal(pseudo_logits_1, pseudo_labels)
                pseudo_label_loss_2 = gradient_aware_focal(pseudo_logits_2, pseudo_labels)
                pseudo_label_loss = (pseudo_label_loss_1 + pseudo_label_loss_2) / 2

                adap_lpl_loss_1 = adaptive_lpl_loss(pseudo_logits_1, None, pseudo_labels, is_logits=True)[0]
                adap_lpl_loss_2 = adaptive_lpl_loss(pseudo_logits_2, None, pseudo_labels, is_logits=True)[0]
                pseudo_lpl_loss = (adap_lpl_loss_1 + adap_lpl_loss_2) / 2

                train_loss = classification_loss + contrastive_loss + current_mu * consistency_loss + \
                    current_mu * pseudo_label_loss + current_mu * clustering_loss + current_mu * pseudo_lpl_loss

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
                lr_scheduler.step()
                train_loss_list.append(train_loss.cpu().detach().numpy())

            # Training metrics are computed on the labelled nodes of the LAST
            # batch only.
            last_logits = train_batch_logits.detach()
            last_pred = torch.argmax(last_logits, dim=1)
            tr_pred = torch.sum(last_pred == batch_labels) / batch_labels.shape[0]
            score = torch.softmax(last_logits, dim=1)[:, 1].cpu().numpy()

            pred_labels = last_pred.cpu().numpy()
            batch_labels_np = batch_labels.cpu().numpy()

            pos_indices = (batch_labels_np == 1)
            neg_indices = (batch_labels_np == 0)

            train_acc1 = np.mean(pred_labels[pos_indices] == batch_labels_np[pos_indices]) if np.any(pos_indices) else 0.0
            train_acc0 = np.mean(pred_labels[neg_indices] == batch_labels_np[neg_indices]) if np.any(neg_indices) else 0.0

            train_gmean = calculate_g_mean(batch_labels_np, pred_labels)

            # Best effort: AP/AUC raise when the batch holds a single class.
            try:
                log_msg = ('Epoch {:03d}, train_loss:{:4f}, '
                            'train_ap:{:.4f}, train_acc:{:.4f}, train_auc:{:.4f}, '
                            'train_acc1:{:.4f}, train_acc0:{:.4f}, train_gmean:{:.4f}')

                logging.info(log_msg.format(epoch,
                                            np.mean(train_loss_list),
                                            average_precision_score(batch_labels_np, score),
                                            tr_pred.item(),
                                            roc_auc_score(batch_labels_np, score),
                                            train_acc1, train_acc0, train_gmean))

                if epoch_pseudo_pos_count > 0 or epoch_pseudo_neg_count > 0:
                    pseudo_total = epoch_pseudo_pos_count + epoch_pseudo_neg_count
                    logging.info(f'Epoch {epoch} 伪标签统计: 正样本 {epoch_pseudo_pos_count} ({epoch_pseudo_pos_count/pseudo_total:.2%}), '
                                f'负样本 {epoch_pseudo_neg_count} ({epoch_pseudo_neg_count/pseudo_total:.2%}), '
                                f'总计 {pseudo_total}')
            except Exception as e:
                logging.error(f"Error calculating metrics: {e}")

            # ---- validation (single pass) ----
            # Logits and labels are collected in the same pass: the loader is
            # shuffled, so a second pass would yield a different order and the
            # metrics would be computed on mismatched pairs.
            # NOTE(review): val_loss_list sums per-batch mean losses and is
            # then divided by the sample count; kept because the early
            # stopper depends on this scale.
            val_loss_list = 0
            val_all_list = 0
            val_label_chunks = []
            val_logit_chunks = []
            model.eval()
            with torch.no_grad():
                for input_nodes, seeds, blocks in val_dataloader:
                    batch_inputs, batch_work_inputs, batch_labels, lpa_labels = load_lpa_subtensor(
                        num_feat, cat_feat, labels, seeds, input_nodes, device)

                    blocks = [block.to(device) for block in blocks]
                    val_batch_logits, _ = model(
                        blocks, batch_inputs, lpa_labels, batch_work_inputs)
                    # exp() of the model output is stored as the probability.
                    oof_predictions[seeds] = torch.exp(val_batch_logits)

                    known = batch_labels != 2  # drop nodes without a label
                    val_batch_logits = val_batch_logits[known]
                    batch_labels = batch_labels[known]

                    val_loss_list = val_loss_list + \
                        F.nll_loss(val_batch_logits, batch_labels)
                    val_all_list = val_all_list + batch_labels.shape[0]

                    val_label_chunks.append(batch_labels.cpu().numpy())
                    val_logit_chunks.append(val_batch_logits.cpu())

            if val_label_chunks:
                all_val_labels = np.concatenate(val_label_chunks)
                all_val_logits = torch.cat(val_logit_chunks, dim=0)

                val_auc, val_ap, val_f1, val_gmean, val_acc0, val_acc1, val_acc = test(all_val_logits, all_val_labels)

                log_msg = ('Epoch {:03d} validation: val_loss:{:4f}, val_ap:{:.4f}, '
                          'val_acc:{:.4f}, val_auc:{:.4f}, val_acc1:{:.4f}, val_acc0:{:.4f}, val_gmean:{:.4f}, val_f1:{:.4f}')

                logging.info(log_msg.format(epoch,
                                            val_loss_list/val_all_list,
                                            val_ap,
                                            val_acc,
                                            val_auc,
                                            val_acc1, val_acc0, val_gmean, val_f1))

            # ---- test evaluation after every epoch ----
            model.eval()
            test_batch_all_labels = []
            test_logits_list = []

            with torch.no_grad():
                for input_nodes, seeds, blocks in test_dataloader:
                    batch_inputs, batch_work_inputs, batch_labels, lpa_labels = load_lpa_subtensor(
                        num_feat, cat_feat, labels, seeds, input_nodes, device)
                    blocks = [block.to(device) for block in blocks]
                    test_batch_logits, _ = model(blocks, batch_inputs, lpa_labels, batch_work_inputs)

                    test_batch_all_labels.append(batch_labels.cpu().numpy())
                    test_logits_list.append(test_batch_logits.cpu())

                    if getattr(earlystoper, 'is_best', False):
                        test_predictions[seeds] = test_batch_logits

            if len(test_batch_all_labels) > 0:
                all_test_labels = np.concatenate(test_batch_all_labels)
                all_test_logits = torch.cat(test_logits_list, dim=0)

                known = all_test_labels != 2  # drop nodes without a label
                all_test_labels = all_test_labels[known]
                all_test_logits = all_test_logits[known]

                test_auc, test_ap, test_f1, test_gmean, test_acc0, test_acc1, test_acc = test(all_test_logits, all_test_labels)

                current_test_metrics = {
                    'auc': test_auc,
                    'f1': test_f1,
                    'ap': test_ap,
                    'acc1': test_acc1,
                    'acc0': test_acc0,
                    'gmean': test_gmean,
                    'acc': test_acc,
                    'epoch': epoch,
                }

                logging.info(f'Epoch {epoch} test metrics:')
                logging.info(f'Test AUC: {test_auc:.4f}')
                logging.info(f'Test F1: {test_f1:.4f}')
                logging.info(f'Test AP: {test_ap:.4f}')
                logging.info(f'Test ACC: {test_acc:.4f}')
                logging.info(f'Test ACC1: {test_acc1:.4f}')
                logging.info(f'Test ACC0: {test_acc0:.4f}')
                logging.info(f'Test G-mean: {test_gmean:.4f}')

                if current_test_metrics['auc'] > best_test_metrics['auc']:
                    best_test_metrics = current_test_metrics.copy()
                    logging.info(f'New best test metrics at epoch {epoch}!')

                logging.info(f'Best test metrics so far (at epoch {best_test_metrics["epoch"]}):')
                logging.info(f'Best AUC: {best_test_metrics["auc"]:.4f}')
                logging.info(f'Best F1: {best_test_metrics["f1"]:.4f}')
                logging.info(f'Best AP: {best_test_metrics["ap"]:.4f}')
                logging.info(f'Best ACC1: {best_test_metrics["acc1"]:.4f}')
                logging.info(f'Best ACC0: {best_test_metrics["acc0"]:.4f}')
                logging.info(f'Best G-mean: {best_test_metrics["gmean"]:.4f}')

            earlystoper.earlystop(val_loss_list/val_all_list, model)
            if earlystoper.is_earlystop:
                logging.info("Early Stopping!")
                break

        logging.info("Best val_loss is: {:.7f}".format(earlystoper.best_cv))

        # No separate final test run: the test set was scored every epoch.
        logging.info("\nBest test metrics (at epoch {}):".format(best_test_metrics['epoch']))
        logging.info("Best test AUC: {:.4f}".format(best_test_metrics['auc']))
        logging.info("Best test F1: {:.4f}".format(best_test_metrics['f1']))
        logging.info("Best test AP: {:.4f}".format(best_test_metrics['ap']))
        logging.info("Best test ACC1: {:.4f}".format(best_test_metrics['acc1']))
        logging.info("Best test ACC0: {:.4f}".format(best_test_metrics['acc0']))
        logging.info("Best test G-mean: {:.4f}".format(best_test_metrics['gmean']))

        # Out-of-fold AP over the folds finished so far; unlabelled targets
        # (2) are counted as negatives.
        logging.info("\nFinal metrics summary:")
        y_target[y_target == 2] = 0
        my_ap = average_precision_score(y_target, torch.softmax(oof_predictions, dim=1).cpu()[train_idx, 1])
        logging.info("NN out of fold AP is: {:.4f}".format(my_ap))

    # The original returned `feat_data` and `g`, which are not defined in this
    # function; return the corresponding inputs instead.
    return feat_df, labels, train_idx, test_idx, graph, cat_features

def _build_sequential_edges(data, columns, edge_per_trans=3):
    """Build edges linking each row to itself and its successors in time.

    For every column in ``columns`` the rows of ``data`` are grouped by that
    column's value and each group is sorted by the "Time" column. Inside a
    group, the row at sorted position ``i`` gets an edge to the rows at
    positions ``i, i + 1, ..., i + edge_per_trans - 1`` (clipped at the end
    of the group). The ``j == 0`` case yields a self-loop.

    :param data: DataFrame containing a "Time" column and every column
        listed in ``columns``; its index values are used as node ids
    :param columns: column names to group by
    :param edge_per_trans: number of forward offsets (including 0) per row
    :returns: ``(src, tgt)`` numpy arrays of node ids
    """
    all_src, all_tgt = [], []
    for column in columns:
        for _, group in data.groupby(column):
            ordered = group.sort_values(by="Time").index
            group_len = len(ordered)
            for i in range(group_len):
                # min() replaces the per-edge ``i + j < group_len`` check.
                for j in range(min(edge_per_trans, group_len - i)):
                    all_src.append(ordered[i])
                    all_tgt.append(ordered[i + j])
    return np.array(all_src), np.array(all_tgt)


def _load_mat_graph_dataset(prefix, dataset, mat_name, adj_name, test_size,
                            first_labeled=0):
    """Load a ``.mat`` dataset plus its pickled adjacency lists as a graph.

    Reads the 'label' and 'features' entries of the ``.mat`` file, splits
    the node indices ``first_labeled .. len(labels) - 1`` into stratified
    train/test parts, builds a DGL graph from the adjacency lists, and
    saves the graph to ``<prefix>graph-<dataset>.bin``.

    :param prefix: data directory prefix (string, concatenated with names)
    :param dataset: dataset name, used for the saved graph file name
    :param mat_name: file name of the ``.mat`` file
    :param adj_name: file name of the pickled adjacency lists
    :param test_size: passed to ``train_test_split``
    :param first_labeled: first node index that takes part in the split
    :returns: ``(feat_data, labels, train_idx, test_idx, g)``
    """
    data_file = loadmat(prefix + mat_name)
    labels = pd.DataFrame(data_file['label'].flatten())[0]
    # toarray() yields the same dense ndarray as the former .todense().A
    # without going through np.matrix.
    feat_data = pd.DataFrame(data_file['features'].toarray())

    # NOTE: pickle.load can execute arbitrary code; only load adjacency
    # files that come from a trusted source.
    with open(prefix + adj_name, 'rb') as file:
        homo = pickle.load(file)

    index = list(range(first_labeled, len(labels)))
    split_labels = labels[first_labeled:]
    train_idx, test_idx, _, _ = train_test_split(
        index, split_labels, stratify=split_labels, test_size=test_size,
        random_state=72, shuffle=True)

    src = []  # edge origin nodes
    tgt = []  # edge destination nodes
    for i in homo:
        for j in homo[i]:
            src.append(i)
            tgt.append(j)
    g = dgl.graph((np.array(src), np.array(tgt)))
    g.ndata['label'] = torch.from_numpy(labels.to_numpy()).to(torch.long)
    g.ndata['feat'] = torch.from_numpy(
        feat_data.to_numpy()).to(torch.float32)
    dgl.data.utils.save_graphs(prefix + "graph-{}.bin".format(dataset), [g])

    return feat_data, labels, train_idx, test_idx, g


def load_gtan_data(dataset: str, test_size: float):
    """
    Load graph, feature, and label given dataset name.

    Besides returning the data, this function writes files into the data
    directory: the built graph is always saved as ``graph-<dataset>.bin``,
    and for "S-FFSD" the encoded features and labels are additionally
    written to ``S-FFSD_feat_data.csv`` and ``S-FFSD_label_data.csv``.

    :param dataset: the dataset name; one of "S-FFSD", "yelp", "amazon"
    :param test_size: the size of test set. For "S-FFSD" the value actually
        passed to ``train_test_split`` is ``test_size / 2``
        (NOTE(review): kept from the original code; confirm it is intended).
    :returns: ``(feat_data, labels, train_idx, test_idx, g, cat_features)``
        - features, labels, train node indices, test node indices, the DGL
        graph, and the list of categorical feature column names
    :raises ValueError: if ``dataset`` is not a supported name
    """
    supported = ("S-FFSD", "yelp", "amazon")
    if dataset not in supported:
        # Fail early with a clear message instead of an UnboundLocalError
        # at the return statement.
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(
                dataset, ", ".join(supported)))

    # prefix = './antifraud/data/'
    prefix = os.path.join(os.path.dirname(__file__), "..", "..", "data/")

    if dataset == "S-FFSD":
        cat_features = ["Target", "Location", "Type"]

        df = pd.read_csv(prefix + "S-FFSDneofull.csv")
        df = df.loc[:, ~df.columns.str.contains('Unnamed')]
        data = df[df["Labels"] <= 2]
        data = data.reset_index(drop=True)

        # Edges are built from the raw column values, before the columns
        # are label-encoded below.
        entity_columns = ["Source", "Target", "Location", "Type"]
        alls, allt = _build_sequential_edges(
            data, entity_columns, edge_per_trans=3)
        g = dgl.graph((alls, allt))

        for col in entity_columns:
            le = LabelEncoder()
            data[col] = le.fit_transform(data[col].apply(str).values)
        feat_data = data.drop("Labels", axis=1)
        labels = data["Labels"]
        ###
        feat_data.to_csv(prefix + "S-FFSD_feat_data.csv", index=None)
        labels.to_csv(prefix + "S-FFSD_label_data.csv", index=None)
        ###
        index = list(range(len(labels)))
        g.ndata['label'] = torch.from_numpy(
            labels.to_numpy()).to(torch.long)
        g.ndata['feat'] = torch.from_numpy(
            feat_data.to_numpy()).to(torch.float32)
        graph_path = prefix + "graph-{}.bin".format(dataset)
        dgl.data.utils.save_graphs(graph_path, [g])

        train_idx, test_idx, _, _ = train_test_split(
            index, labels, stratify=labels, test_size=test_size / 2,
            random_state=72, shuffle=True)

    elif dataset == "yelp":
        cat_features = []
        feat_data, labels, train_idx, test_idx, g = _load_mat_graph_dataset(
            prefix, dataset, 'YelpChi.mat', 'yelp_homo_adjlists.pickle',
            test_size)

    else:  # dataset == "amazon"
        cat_features = []
        # Only nodes from index 3305 onwards take part in the split
        # (presumably the earlier nodes are unlabeled - TODO confirm).
        feat_data, labels, train_idx, test_idx, g = _load_mat_graph_dataset(
            prefix, dataset, 'Amazon.mat', 'amz_homo_adjlists.pickle',
            test_size, first_labeled=3305)

    return feat_data, labels, train_idx, test_idx, g, cat_features
