import argparse
import multiprocessing
import os
import random
import sys
import time
from multiprocessing import Pool
# from GNN_local_refinement import run_multi_refinement, run_refinement_chain
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from sklearn.linear_model import LogisticRegression
from torch import Tensor
from torch.serialization import add_safe_globals
from torch.utils.data import DataLoader
from controlsnr import find_a_given_snr
from data_generator_dcsbm import Generator, simple_collate_fn
from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass, from_scores_to_labels_multiclass_batch
from models import GNN_multiclass, GNN_multiclass_second_period
from spectral_clustering import spectral_clustering_adj, local_refinement_by_neighbors_multi
from train_first_period import train_first_period_with_early_stopping
from train_second_period import train_second_period_with_early_stopping
from pathlib import Path
import pandas as pd
import numpy as np

def setup_logger(prefix="main_gnn"):
    """
    Redirect this process's stdout and stderr to a fresh per-run log file.

    The file is created in the current working directory and named
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``. It is opened line-buffered so every
    printed line is flushed immediately.

    Args:
        prefix: leading part of the log file name.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    log_filename = f"{prefix}_{timestamp}_{os.getpid()}.log"
    # The encoding is explicit because this script prints Chinese text. With the
    # locale default (e.g. a non-UTF-8 code page), writing it could raise
    # UnicodeEncodeError. The handle is deliberately left open: it becomes
    # sys.stdout / sys.stderr for the rest of the process.
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")

def load_best_model_into(model, ckpt_path, device, trust_checkpoint=True):
    """
    加载 ckpt 到已实例化的 model：
      1) 优先在 weights_only=True 下加载（安全）
      2) 如需自定义类，加入白名单后再试
      3) 仍失败且你确认 ckpt 可信 -> 回退 weights_only=False
    兼容：
      - torch.save({'model_state': state_dict, ...})
      - 直接保存的 state_dict
      - 旧格式：torch.save(model) 整个模型对象（不推荐，仅做兼容）
    """
    # ---- 1) 安全路径：只加载权重/张量 ----
    try:
        blob = torch.load(ckpt_path, map_location=device, weights_only=True)
    except Exception:
        # ---- 2) 需要白名单你的自定义类（weights_only=True 下仍可能需要）----
        try:
            add_safe_globals([GNN_multiclass])  # 可加入多个类
            blob = torch.load(ckpt_path, map_location=device, weights_only=True)
        except Exception:
            # ---- 3) 最后回退（仅可信文件）----
            if not trust_checkpoint:
                raise
            blob = torch.load(ckpt_path, map_location=device, weights_only=False)

    # ---- 解析多种常见 ckpt 结构 ----
    if isinstance(blob, dict):
        if 'model_state' in blob:
            state = blob['model_state']
        elif 'model_state_dict' in blob:
            state = blob['model_state_dict']
        elif 'state_dict' in blob:
            state = blob['state_dict']
        else:
            # 可能就是个纯 state_dict
            if all(isinstance(v, torch.Tensor) for v in blob.values()):
                state = blob
            else:
                raise RuntimeError(f"Unrecognized checkpoint dict keys: {list(blob.keys())[:8]}")
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"[load_state_dict] missing={missing}, unexpected={unexpected}")
        model.to(device).eval()
        return model

    # 兼容：旧式“整个模型对象”保存
    if hasattr(blob, 'state_dict'):
        try:
            missing, unexpected = model.load_state_dict(blob.state_dict(), strict=False)
            if missing or unexpected:
                print(f"[load_state_dict] missing={missing}, unexpected={unexpected}")
            model.to(device).eval()
            return model
        except Exception:
            # 架构对不上，直接返回反序列化的对象（仅当可信）
            mdl = blob.to(device)
            if isinstance(mdl, torch.nn.Module):
                mdl.eval()
            return mdl

    # 到这里说明既不是 dict 也不是 nn.Module
    raise RuntimeError(f"Unsupported checkpoint type: {type(blob)}")

def maybe_freeze(model, freeze=True):
    """
    Put ``model`` into eval mode, optionally disabling gradients first.

    When ``freeze`` is true, every parameter gets ``requires_grad = False``.
    The model is switched to ``eval()`` in both cases. Returns None.
    """
    params_to_freeze = model.parameters() if freeze else ()
    for param in params_to_freeze:
        param.requires_grad = False
    model.eval()

# Command-line configuration. Parsed once at import time (see parse_args at the
# end of this block) into the module-level ``args`` namespace.
# NOTE(review): most options use nargs='?' with const=1. Passing such a flag with
# no value (e.g. bare ``--lr`` or ``--path_gnn``) yields the int 1, because
# argparse does not apply ``type`` to ``const``, instead of the default value.
# Confirm this is intended.
parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#          Parameters are declared up front for use later in the script        #
###############################################################################

parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6000))
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(1))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1000))
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
# SBM connection probabilities (presumably within- / between-community).
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 1)
parser.add_argument('--mode', nargs='?', const=1, type=str, default='train')
# Directory containing this script (the single-argument os.path.join is a no-op).
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): --freeze_bn stores True into ``eval_vs_train``, but the default set
# on the next line is already True, so the flag cannot change the value. Confirm intent.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

###############################################################################
#                                 GNN first period                            #
###############################################################################
parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
# Passed to get_gnn_inputs(W, J) when building the GNN inputs.
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default= 8)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default= 3)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=4)
# Presumably the number of nodes per graph for each split. TODO confirm.
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

args = parser.parse_args()

# Tensor types used by the ``.type(dtype)`` / ``.type(dtype_l)`` casts elsewhere:
# CUDA variants when a GPU is visible, CPU variants otherwise.
# No RNG seeding is performed here.
if not torch.cuda.is_available():
    dtype, dtype_l = torch.FloatTensor, torch.LongTensor
else:
    dtype, dtype_l = torch.cuda.FloatTensor, torch.cuda.LongTensor

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Fixed-width console templates. template1/template2 are the header/row pair
# for the 7-column per-iteration log; template3/template4 are a 3-column pair.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'
template3 = '{:<10} {:<10} {:<10} '
template4 = '{:<10} {:<10.5f} {:<10.5f} \n'

# 9-column header/row pair.
template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row    = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'



class SBMDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over graphs stored one per ``.npz`` file.

    Each file must contain the CSR components ``adj_data``, ``adj_indices``,
    ``adj_indptr``, ``adj_shape`` and a ``labels`` array.
    """

    def __init__(self, npz_file_list):
        # Paths are kept as given; files are only opened in __getitem__.
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """
        Load graph ``idx``.

        Returns a dict with ``'adj'`` (scipy CSR matrix), ``'labels'`` (the
        stored labels array) and ``'num_nodes'`` (``adj.shape[0]``).
        """
        # np.load on an .npz returns an NpzFile that holds the file open. The
        # context manager closes it so repeated __getitem__ calls don't leak handles.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}


# Sparse-matrix, local-refinement and eigensolver imports used by the code below.
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors
from scipy.sparse import issparse
from scipy.sparse.linalg import eigsh as sparse_eigsh
from numpy.linalg import eigh as dense_eigh



def extract_spectral_features(
    adj,
    k: int,
    *,
    row_norm: bool = True,
    use_abs: bool = True,   # True: rank by |lambda|; False: rank by largest algebraic value
):
    """
    Return the leading ``k`` eigenvectors of an adjacency matrix as node features.

    Parameters
    ----------
    adj : (N, N) scipy sparse matrix or array-like
        Adjacency matrix. Assumed symmetric and real: ``eigsh`` / ``eigh`` only
        give meaningful results for symmetric input.
    k : int
        Number of eigenvectors to return.
    row_norm : bool, default True
        L2-normalize each row (node embedding) to remove magnitude differences.
    use_abs : bool, default True
        If True, select eigenpairs by largest ``|lambda|``. Otherwise select by
        largest algebraic value (the positive end of the spectrum).

    Returns
    -------
    U : np.ndarray of shape (N, min(k, N))
        Eigenvectors ordered by the selection criterion, descending. Each
        column's sign is flipped so its mean is non-negative, which avoids
        arbitrary sign flips between runs.

    Notes
    -----
    Sparse input uses ARPACK (``scipy.sparse.linalg.eigsh``), which requires
    ``k < N``. When ``k >= N - 1`` the matrix is densified and solved with
    ``numpy.linalg.eigh`` instead. This way small graphs still get
    ``min(k, N)`` columns, as the dense path does, rather than being silently
    truncated to ``N - 2`` columns.
    """
    n = adj.shape[0]
    if issparse(adj) and k < n - 1:
        A = adj.astype(np.float64)  # eigsh does not accept integer matrices
        # "LM": largest magnitude; "LA": largest algebraic (positive end).
        w, U = sparse_eigsh(A, k=max(1, k), which="LM" if use_abs else "LA")
        # Re-sort columns by the selection criterion for a stable order.
        order = np.argsort(np.abs(w) if use_abs else w)[::-1]
        U = U[:, order[:k]]
    else:
        if issparse(adj):
            A = adj.toarray().astype(np.float64)
        else:
            A = np.asarray(adj, dtype=np.float64)
        # Full spectrum (ascending), then pick the top k by the criterion.
        w_all, V_all = dense_eigh(A)
        if use_abs:
            idx = np.argsort(np.abs(w_all))[-k:]          # k largest |lambda|
            idx = idx[np.argsort(np.abs(w_all[idx]))[::-1]]
        else:
            idx = np.argsort(w_all)[-k:]                  # k largest values
            idx = idx[::-1]
        U = V_all[:, idx]

    # Column sign alignment: make each column's mean non-negative.
    col_mean = U.mean(axis=0, keepdims=True)
    U *= np.where(col_mean >= 0, 1.0, -1.0)

    # Optional row normalization (each node vector scaled to ~unit length).
    if row_norm:
        U = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)

    return U


import joblib

# Default the OpenMP / MKL thread counts to 1 unless the user already set them.
# NOTE(review): numpy / torch / sklearn are imported above this point, so these
# settings presumably only influence processes started afterwards (e.g. the
# joblib worker processes used by train_logistic_regression), not thread pools
# already initialized in this process. Confirm.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

def _load_and_feat(path, k):
    """
    Load one ``.npz`` graph and return ``(spectral_features, labels)``.

    The file must hold the CSR components ``adj_data``, ``adj_indices``,
    ``adj_indptr``, ``adj_shape`` and a ``labels`` array. Features come from
    ``extract_spectral_features(adj, k)``, which uses sparse ``eigsh`` for
    sparse input. Labels are flattened to a 1-D int array.
    """
    # Close the NpzFile handle before the (expensive) eigendecomposition. This
    # function runs once per training graph.
    with np.load(path) as data:
        adj = csr_matrix((data['adj_data'], data['adj_indices'], data['adj_indptr']),
                         shape=tuple(data['adj_shape']))
        labels = np.asarray(data['labels']).ravel().astype(int)
    feats = extract_spectral_features(adj, k)
    return feats, labels

def train_logistic_regression(gen, k, save_path="./models/node_info_added_lr_model.pkl", n_jobs=None):
    """
    Fit (or load a cached) logistic-regression classifier on spectral node features.

    If ``save_path`` exists, the pickled model is loaded and returned without
    retraining. The cached model is not checked against ``k``. Otherwise each
    path in ``gen.data_train`` is turned into per-node features and labels by
    ``_load_and_feat(path, k)`` in parallel worker processes. All nodes are
    stacked into one training matrix, the model is fit, and it is saved to
    ``save_path``.

    Args:
        gen: object exposing ``data_train`` (iterable of ``.npz`` graph paths).
        k: number of spectral features per node.
        save_path: cache location of the fitted model (joblib pickle).
        n_jobs: number of worker processes. ``None`` means half the CPU count,
            at least 1.

    Returns:
        The fitted (or loaded) ``LogisticRegression`` model.

    Raises:
        ValueError: if ``gen.data_train`` yields no graphs.
    """
    # If the model already exists, load it directly.
    # joblib.load unpickles arbitrary objects: only point save_path at trusted files.
    if os.path.exists(save_path):
        print(f"检测到已存在的模型文件 {save_path}，直接加载。")
        return joblib.load(save_path)

    if n_jobs is None:
        # Computed per call rather than at import. os.cpu_count() may return
        # None, and joblib rejects n_jobs == 0 (e.g. on a single-core machine).
        n_jobs = max(1, (os.cpu_count() or 1) // 2)

    # Parallel feature extraction.
    results = Parallel(n_jobs=n_jobs, prefer="processes")(
        delayed(_load_and_feat)(fp, k) for fp in gen.data_train
    )
    if not results:
        raise ValueError("gen.data_train is empty; no graphs to train the logistic regression on.")
    all_features, all_labels = zip(*results)

    X_train = np.vstack(all_features).astype(np.float32, copy=False)
    y_train = np.concatenate(all_labels)

    # NOTE(review): `multi_class` is deprecated in recent scikit-learn releases.
    # It is kept so results stay identical on the versions this was written for.
    lr_model = LogisticRegression(
        multi_class='multinomial',
        solver='saga',  # scales to large and sparse data
        max_iter=1000,
        random_state=42
    )

    lr_model.fit(X_train, y_train)

    # Save the model. A bare file name has no directory to create.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    joblib.dump(lr_model, save_path)
    print(f"模型已保存到 {save_path}")

    return lr_model


def evaluate_using_logitic(W, true_labels, lr_model, k):
    """
    Score the spectral-feature logistic-regression baseline on one graph.

    The leading ``k`` eigenvectors of ``W`` (from ``extract_spectral_features``)
    are classified by ``lr_model.predict``. The prediction is scored against
    ``true_labels``, then post-processed with ``local_refinement_by_neighbors``
    and scored again.

    :param W: adjacency of the graph to test, presumably with a leading batch
        axis of size 1 (removed via ``squeeze(0)``).
    :param true_labels: ground-truth labels, in whatever form
        ``compute_acc_ari_nmi`` accepts.
    :param lr_model: fitted classifier exposing ``predict`` (e.g. the model from
        ``train_logistic_regression``).
    :param k: number of communities; also used as the number of spectral features.
    :return: ``(acc, ari, nmi, acc_refined, ari_refined, nmi_refined)``.
    """
    W = W.squeeze(0)

    # Spectral features -> predicted labels.
    features = extract_spectral_features(W, k)
    pred_labels = lr_model.predict(features)

    # Raw metrics. compute_acc_ari_nmi returns (best_acc, best_pred, ari, nmi).
    acc_logistic, logistic_best_matched_pred, ari_logistic, nmi_logistic = \
        compute_acc_ari_nmi(pred_labels, true_labels, k)

    # Local refinement of the best-matched prediction, then score again.
    logistic_best_matched_pred_refined = local_refinement_by_neighbors(W, logistic_best_matched_pred, k)
    acc_logistic_refined, _, ari_logistic_refined, nmi_logistic_refined = \
        compute_acc_ari_nmi(logistic_best_matched_pred_refined, true_labels, k)

    return acc_logistic, ari_logistic, nmi_logistic, acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined

def get_available_device():
    """
    Return the first CUDA device that can actually hold a tensor, else the CPU.

    GPUs are probed in index order. Each probe makes the GPU current and
    allocates a one-element tensor on it. A RuntimeError (e.g. out of memory)
    skips to the next index.
    """
    gpu_total = torch.cuda.device_count()
    gpu_idx = 0
    while gpu_idx < gpu_total:
        try:
            torch.cuda.set_device(gpu_idx)
            torch.zeros(1).cuda()
        except RuntimeError:
            gpu_idx += 1
            continue
        return torch.device(f"cuda:{gpu_idx}")
    return torch.device("cpu")


# Module-level device, chosen once at import time: the first GPU that accepts an
# allocation, otherwise the CPU.
device = get_available_device()


# ##############Get the labels from the first period and use this labels to train the local refinement###################
# def train_single_local_refinement(gnn_first_period, gnn_local_refine, n_classes, W, true_labels ,optimizer, iter):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     sam_com_matrix, pred_labels = get_second_period_labels_single(gnn_first_period, W, true_labels,n_classes, args)
#
#     true_labels = true_labels.type(dtype_l)
#
#     start = time.time()
#     WW, x = get_gnn_inputs_local_refinement(W, args.J, sam_com_matrix, pred_labels, n_classes)
#
#     # **移动数据到计算设备 (GPU/CPU)**
#     WW = WW.to(device)
#     x = x.to(device)
#
#     optimizer.zero_grad(set_to_none=True)
#     pred = gnn_local_refine(WW.type(dtype), x.type(dtype))
#
#     loss = compute_loss_multiclass(pred, true_labels, n_classes)  # 计算损失
#     loss.backward()
#     nn.utils.clip_grad_norm_(gnn_local_refine.parameters(), args.clip_grad_norm)
#     optimizer.step()
#
#     acc, best_matched_pred = compute_accuracy_multiclass(pred, true_labels, n_classes)
#
#     elapsed_time = time.time() - start
#
#     if torch.cuda.is_available():
#         loss_value = float(loss.data.cpu().numpy())
#     else:
#         loss_value = float(loss.data.numpy())
#
#     info = ['iter', 'avg loss', 'avg acc', 'edge_density',
#             'noise', 'model', 'elapsed']
#     out = [iter, loss_value, acc, args.edge_density,
#            args.noise, 'GNN', elapsed_time]
#
#     print(template1.format(*info))
#     print(template2.format(*out))
#
#     # **释放 GPU 显存**
#     WW = None
#     x = None
#
#     return loss_value, acc
#
#
# def train_local_refinement(gnn_first_period, gnn_local_refine, n_classes=args.n_classes, iters=args.num_examples_train):
#     gnn_local_refine.train()
#     optimizer = torch.optim.Adamax(gnn_local_refine.parameters(), lr=args.lr)
#
#     loss_lst = np.zeros([iters])
#     acc_lst = np.zeros([iters])
#
#     for it in range(iters):
#         W_i = cached_graphs[it]
#         true_labels_i = cached_labels[it]
#
#         loss_single, acc_single = train_single_local_refinement(gnn_first_period, gnn_local_refine, n_classes, W_i, true_labels_i,
#                                                                 optimizer, it)
#
#         loss_lst[it] = loss_single
#         acc_lst[it] = acc_single
#
#         torch.cuda.empty_cache()

# def get_insample_acc_lost_single(gnn_first_period, gnn_local_refine, W_i, labels_i, n_classes, args):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     W, true_labels = W_i, labels_i
#
#     sam_com_matrix, pred_label_first ,first_in_sample_loss, first_in_sample_acc \
#      = in_sample_test_first_get_second_period_labels_single(gnn_first_period, W, true_labels,n_classes, args)
#
#     WW, x = get_gnn_inputs_local_refinement(W, args.J, sam_com_matrix, pred_label_first, n_classes)
#
#     if torch.cuda.is_available():
#         WW = WW.to(device)
#         x = x.to(device)
#
#     pred_single_second = gnn_local_refine(WW.type(dtype), x.type(dtype))
#
#     # 计算第二阶段损失和准确率
#     second_in_sample_loss = compute_loss_multiclass(pred_single_second, true_labels, n_classes)
#     second_in_sample_acc, best_matched_pred = compute_accuracy_multiclass(pred_single_second, true_labels, n_classes)
#
#     WW = None
#     x = None
#
#     return first_in_sample_loss, first_in_sample_acc, float(second_in_sample_loss.data.cpu().numpy()), second_in_sample_acc
#
# def get_insample_acc_lost(gnn_first_period, gnn_local_refinement, n_classes, iters=args.num_examples_test,
#                           filename="in_sample_test_results_sparsity.csv"):
#
#     gnn_first_period.train()
#     gnn_local_refinement.train()
#
#     # in_sample_loss_lst_first = np.zeros([iters])
#     in_sample_acc_lst_first = np.zeros([iters])
#
#     # in_sample_loss_lst_second = np.zeros([iters])
#     in_sample_acc_lst_second = np.zeros([iters])
#
#     for it in range(iters):
#         W_i = cached_graphs[it]
#         labels_i = cached_labels[it]
#
#         first_in_sample_loss, first_in_sample_acc, second_in_sample_loss, second_in_sample_acc = get_insample_acc_lost_single(
#             gnn_first_period,
#             gnn_local_refinement, W_i, labels_i, n_classes, args)
#
#         in_sample_acc_lst_first[it] = first_in_sample_acc
#
#         in_sample_acc_lst_second[it] = second_in_sample_acc
#
#         torch.cuda.empty_cache()
#     # 计算均值和标准差
#     first_avg_test_acc = np.mean(in_sample_acc_lst_first)
#     first_std_test_acc = np.std(in_sample_acc_lst_first)
#
#     second_avg_test_acc = np.mean(in_sample_acc_lst_second)
#     second_std_test_acc = np.std(in_sample_acc_lst_second)
#
#     n = args.N_train  # 或者 N_test，也可以统一都用 N
#     logn_div_n = np.log(n) / n
#
#     a = args.p_SBM / logn_div_n
#     b = args.q_SBM / logn_div_n
#     k = args.n_classes
#
#     snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
#
#     df = pd.DataFrame([{
#         "n_classes": args.n_classes,
#         "p_SBM": args.p_SBM,
#         "q_SBM": args.q_SBM,
#         "J": args.J,
#         "N_train": args.N_train,
#         "N_test": args.N_test,
#         "first_avg_test_acc": first_avg_test_acc,
#         "first_std_test_acc": first_std_test_acc,
#         "second_avg_test_acc": second_avg_test_acc,
#         "second_std_test_acc": second_std_test_acc,
#         "SNR": snr
#     }])
#
#     # 追加模式写入文件，防止覆盖
#     df.to_csv(filename, mode='a', index=False, header=not pd.io.common.file_exists(filename))

def _as_metric_dict(acc, ari, nmi):
    """Pack one method's (acc, ari, nmi) triple into a dict of Python floats."""
    return {"acc": float(acc), "ari": float(ari), "nmi": float(nmi)}


def test_single_first_period(gnn_first_period, lr_model, gen, n_classes, args, iter, mode='balanced',
                             class_sizes=None, gnn_second_period=None):
    """
    Evaluate the first-period GNN and several baselines on one freshly sampled graph.

    A graph is drawn with ``gen.imbalanced_sample_otf_single``; the order of
    ``class_sizes`` is reversed with probability 0.5. It is then scored with:
      - spectral clustering (``spectral_clustering_adj(..., run_all=True)``):
        normalized, unnormalized and adjacency variants, each raw and refined;
      - the spectral-feature logistic regression ``lr_model``, raw and refined;
      - ``gnn_first_period``, raw and after ``local_refinement_by_neighbors``.
    For ``iter < 10`` a per-node diagnostic sheet is written to an Excel
    workbook under ``penultimate_GNN_Feature/``. It contains labels,
    penultimate GNN features, generator eigenvectors and adjacency eigenvectors.

    Args:
        gnn_first_period: model called as ``model(WW, x)``; must expose
            ``get_penultimate_output()``.
        lr_model: fitted classifier passed to ``evaluate_using_logitic``.
        gen: generator providing ``imbalanced_sample_otf_single``, ``p_SBM``, ``q_SBM``.
        n_classes: number of communities.
        args: parsed CLI namespace (reads ``J``, ``num_layers``, ``edge_density``, ``noise``).
        iter: iteration index, used for logging and the Excel sheet name.
        mode: sampling mode; only ``'imbalanced'`` is implemented.
        class_sizes: community sizes; required for ``mode='imbalanced'``.
        gnn_second_period: accepted for call-site compatibility
            (``test_first_period`` passes it); currently unused.

    Returns:
        ``(loss, metrics)``: the first-period loss as a float, and a dict mapping
        method name to ``{"acc", "ari", "nmi"}``.

    Raises:
        ValueError: for an unsupported ``mode`` or missing ``class_sizes``.
    """
    start = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): train() mode is used even though this is an evaluation pass
    # (gradients are disabled below). This matches the rest of the file, but it
    # changes the behaviour of layers such as BatchNorm/Dropout. Confirm intent.
    gnn_first_period.train()

    # --- 1. Sample one graph ---
    if mode != 'imbalanced':
        # Any other mode would leave W / true_labels / eigvecs_top undefined.
        raise ValueError(f"Unsupported mode {mode!r}: only 'imbalanced' is implemented.")
    if class_sizes is None:
        raise ValueError("class_sizes is required when mode='imbalanced'.")
    # Randomly reverse the size order so either community can be the small one.
    random_class_sizes = class_sizes[::-1] if random.random() < 0.5 else class_sizes

    W, true_labels, eigvecs_top = gen.imbalanced_sample_otf_single(random_class_sizes, is_training=True, cuda=True)
    true_labels = true_labels.type(dtype_l)

    # --- 2. Baselines: spectral clustering variants and logistic regression ---
    # spectral_clustering_adj(A, k, true_labels, normalized=False, *, run_all=False, random_state=0)
    res_spectral = spectral_clustering_adj(
        W, n_classes, true_labels,
        normalized=True,
        run_all=True,
        random_state=0
    )
    (acc_logistic, ari_logistic, nmi_logistic,
     acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined) = \
        evaluate_using_logitic(W, true_labels, lr_model, n_classes)

    # --- 3. GNN forward pass without gradients ---
    WW, x = get_gnn_inputs(W, args.J)
    WW, x = WW.to(device), x.to(device)
    with torch.no_grad():
        pred_single_first = gnn_first_period(WW.type(dtype), x.type(dtype))

    # --- 4. Adjacency eigenvectors for the n_classes largest eigenvalues ---
    W_np = W.squeeze(0).cpu().numpy() if isinstance(W, Tensor) else W.squeeze(0)
    W_for_eig = (W_np + W_np.T) / 2  # symmetrize before eigh
    eigvals_W, eigvecs_W = np.linalg.eigh(W_for_eig)
    eigvals_W, eigvecs_W = np.real(eigvals_W), np.real(eigvecs_W)
    adjacency_eigvecs = eigvecs_W[:, np.argsort(eigvals_W)[-n_classes:][::-1]]
    adjacency_eigvecs /= np.linalg.norm(adjacency_eigvecs, axis=0, keepdims=True) + 1e-12

    # --- 5. Penultimate-layer features, column-normalized ---
    penultimate_features = gnn_first_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)
    # The epsilon keeps an all-zero feature column from turning into NaN.
    penultimate_features /= np.linalg.norm(penultimate_features, axis=0, keepdims=True) + 1e-12

    # --- 6. First-period loss / metrics, then local refinement ---
    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    # Returns (acc_mean, best_matched_preds, ari_mean, nmi_mean).
    acc_gnn_first, gnn_pred_label_first, ari_gnn_first, nmi_gnn_first = \
        gnn_compute_acc_ari_nmi_multiclass(pred_single_first, true_labels, n_classes)
    gnn_refined = local_refinement_by_neighbors(W_np, gnn_pred_label_first, n_classes)
    acc_gnn_refined, _, ari_gnn_refined, nmi_gnn_refined = \
        compute_acc_ari_nmi(gnn_refined, true_labels, n_classes)

    N = true_labels.shape[1]
    elapsed = time.time() - start
    # .cpu() is a no-op for CPU tensors, so one expression covers both devices.
    loss_value = float(loss_test_first.detach().cpu())

    info = ['iter', 'avg loss', 'avg acc', 'edge_density',
            'noise', 'model', 'elapsed']
    out = [iter, loss_value, acc_gnn_first, args.edge_density,
           args.noise, 'GNN', elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    del WW, x

    # --- 7. Per-node diagnostic sheet (first 10 iterations only) ---
    if iter < 10:
        data = {
            'True_Label': true_labels.squeeze(0).cpu().numpy(),
            'Pred_Label_First': gnn_pred_label_first.reshape(-1),
            'Loss_First': [loss_value] * N,
            'Acc_First': [float(acc_gnn_first)] * N,
        }
        # The layout expects 2 * n_classes penultimate columns. Cap at the real
        # width so a narrower layer does not raise IndexError.
        n_pen = min(2 * n_classes, penultimate_features.shape[1])
        for i in range(n_pen):
            data[f'penultimate_GNN_Feature{i + 1}'] = penultimate_features[:, i]
        for i in range(n_classes):
            data[f'eigvecs_top{i + 1}'] = eigvecs_top[:, i]
        for i in range(n_classes):
            data[f'Adj_EigVecs_Top{i + 1}'] = adjacency_eigvecs[:, i]

        df = pd.DataFrame(data)

        root_folder = "penultimate_GNN_Feature"
        subfolder_name = f"penultimate_GNN_Feature_nclasses_{n_classes}"
        output_filename = (
            f"first_gnn_classesizes={class_sizes}_p={gen.p_SBM}_q={gen.q_SBM}_"
            f"j={args.J}_nlyr={args.num_layers}.xlsx"
        )
        output_path = os.path.join(root_folder, subfolder_name, output_filename)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Iteration 0 starts a fresh workbook; later iterations append a sheet.
        if iter == 0:
            df.to_excel(output_path, sheet_name=f'Iteration_{iter}', index=False)
        else:
            with pd.ExcelWriter(output_path, mode='a', engine='openpyxl') as writer:
                df.to_excel(writer, sheet_name=f'Iteration_{iter}', index=False)

    # --- 8. Collect metrics (key order matches METHODS) ---
    metrics = {
        "gnn": _as_metric_dict(acc_gnn_first, ari_gnn_first, nmi_gnn_first),
        "gnn_refined": _as_metric_dict(acc_gnn_refined, ari_gnn_refined, nmi_gnn_refined),
    }
    # Each spectral variant reports (acc, ari, nmi) for the raw clustering ('sc')
    # and for its refined version ('refined').
    for variant in ("normalized", "unnormalized", "adjacency"):
        metrics[f"spectral_{variant}"] = _as_metric_dict(*res_spectral[variant]['sc'])
        metrics[f"spectral_{variant}_refined"] = _as_metric_dict(*res_spectral[variant]['refined'])
    metrics["logistic"] = _as_metric_dict(acc_logistic, ari_logistic, nmi_logistic)
    metrics["logistic_refined"] = _as_metric_dict(
        acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined
    )

    return loss_value, metrics

# def test_single_local_refinement(gnn_first_period, gnn_second_period, A, labels, n_classes, args, iter):
#     """
#     A: numpy.ndarray (N,N)  # 建议确保传入就是 numpy 2D
#     labels: numpy.ndarray (N,) 或 (N,1)，取值任意（如 {1,2}、{3,5,7}）
#     """
#     start = time.time()
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     # === 0) 统一输入形状 ===
#     labels = np.asarray(labels).squeeze()                 # -> (N,)
#     A_np   = np.asarray(A)                                # -> (N,N)
#     assert A_np.ndim == 2 and A_np.shape[0] == A_np.shape[1], f"A shape bad: {A_np.shape}"
#     N = A_np.shape[0]
#
#     # === 1) 标签做通用映射到 0..K-1 ===
#     uniq = np.unique(labels)
#     remap = {v: i for i, v in enumerate(uniq)}
#     labels_mapped = np.vectorize(remap.get)(labels).astype(np.int64)
#     true_labels = torch.tensor(labels_mapped, dtype=torch.long, device=device).unsqueeze(0)  # (1,N)
#
#     # === 2) 谱 & logistic 基线（用 numpy 2D 邻接） ===
#     res_spectral = spectral_clustering_adj(
#         A_np, n_classes, true_labels,  # 若内部期望 1D，可传 true_labels.squeeze(0).cpu().numpy()
#         normalized=True,
#         run_all=True,
#         random_state=0
#     )
#
#     # === 3) GNN 输入 ===
#     W_np = np.expand_dims(A_np, 0)      # shape (1, N, N)
#     WW, x = get_gnn_inputs(W_np, args.J)
#     WW, x = WW.to(device), x.to(device)
#
#     gnn_first_period.train()
#     gnn_second_period.train()
#
#     with torch.no_grad():
#         # --- 阶段一 ---
#         pred_single_first = gnn_first_period(WW.type(torch.float32), x.type(torch.float32))
#         start_x_label = from_scores_to_labels_multiclass_batch(
#             pred_single_first.detach().cpu().numpy()
#         )
#         print("多重GNN_second_period开始")
#         # === 多步 refinement（一次性） ===
#         # total_iters = 5 等价于你原来“先 1 次 + 再 4 次”的 5 层流程
#         ref = run_refinement_chain(
#             gnn_second_period=gnn_second_period,
#             W_np=W_np,
#             init_labels=start_x_label,
#             true_labels=true_labels,
#             args=args,
#             device=device,
#             total_iters=5,     # 需要更改层数时，改这里
#             verbose=True
#         )
#
#         # 兼容你原来的“第二层（第一次 refinement）命名”
#         if ref["first_iter"] is not None:
#             second_pred_label = ref["first_iter"]["pred_label"]
#             second_pred       = ref["first_iter"]["pred"]
#             second_acc        = ref["first_iter"]["acc"]
#             second_ari        = ref["first_iter"]["ari"]
#             second_nmi        = ref["first_iter"]["nmi"]
#             # print(f"[Iter 1] acc={second_acc:.4f}, ari={second_ari:.4f}, nmi={second_nmi:.4f}")
#
#         # 最终（第 total_iters 次）结果
#         final_pred = ref["final"]["pred"]
#         final_acc  = ref["final"]["acc"]
#         final_ari  = ref["final"]["ari"]
#         final_nmi  = ref["final"]["nmi"]
#         print("[Final] "
#               f"acc={final_acc:.4f}, ari={final_ari:.4f}, nmi={final_nmi:.4f}")
#
#     # === 4) 邻接矩阵特征向量（numpy） ===
#     W_for_eig = (A_np + A_np.T) / 2.0
#     eigvals_W, eigvecs_W = np.linalg.eigh(W_for_eig)
#     idx = np.argsort(eigvals_W)[-n_classes:][::-1]
#     adjacency_eigvecs = eigvecs_W[:, idx]
#     adjacency_eigvecs /= (np.linalg.norm(adjacency_eigvecs, axis=0, keepdims=True) + 1e-12)
#
#     # === 5) GNN 倒数第二层特征（按你的写法：列归一化） ===
#     penultimate_features = gnn_first_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)  # (N, D)
#     penultimate_features /= (np.linalg.norm(penultimate_features, axis=0, keepdims=True) + 1e-12)
#
#     penultimate_features_second = gnn_second_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)  # (N, D)
#     penultimate_features_second /= (np.linalg.norm(penultimate_features_second, axis=0, keepdims=True) + 1e-12)
#
#     # === 6) GNN loss & acc ===
#     loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
#     acc_gnn_first, best_matched_pred, ari_gnn_first, nmi_gnn_first = \
#         gnn_compute_acc_ari_nmi_multiclass(pred_single_first, true_labels, n_classes)
#     gnn_pred_label_first = best_matched_pred  # (N,)
#
#     loss_test_second = compute_loss_multiclass(second_pred, true_labels, n_classes)
#     acc_gnn_second, best_matched_pred, ari_gnn_second, nmi_gnn_second = \
#         gnn_compute_acc_ari_nmi_multiclass(second_pred, true_labels, n_classes)
#     gnn_pred_label_second = best_matched_pred  # (N,)
#
#     # local refinement（A 用 numpy，labels 用 1D）
#     gnn_refined = local_refinement_by_neighbors(A_np, gnn_pred_label_first, n_classes)
#     acc_gnn_refined, best_matched_pred , ari_gnn_refined, nmi_gnn_refined = \
#         compute_acc_ari_nmi(gnn_refined, true_labels, n_classes)
#     gnn_refined_label = best_matched_pred
#     print("多重GNN_local_refinement开始")
#     # 多跑 10 步 refinement，记录每步 acc/ARI/NMI 与变化数量
#     res_multi = local_refinement_by_neighbors_multi(
#         A=A_np,  # (N,N) numpy
#         init_labels=start_x_label,  # (N,)
#         num_classes=n_classes,
#         true_labels=true_labels,  # 你的 (1,N) torch.LongTensor 或 (N,) numpy
#         num_iters=10,
#         alpha=1e-6,
#         random_state=0,
#         tol=0,  # 不提前停
#         verbose=True,
#         return_history=True
#     )
#
#     elapsed = time.time() - start
#     loss_value = float(loss_test_first.detach().cpu().numpy())
#     print(f"[Iter {iter}] Loss={loss_value:.4f}, Acc={acc_gnn_first:.4f}, Time={elapsed:.2f}s")
#
#     # === 7) 保存 Excel（按列名维度自适配） ===
#     data = {
#         'True_Label': true_labels.squeeze(0).cpu().numpy(),          # (N,)
#         'Pred_Label_First': gnn_pred_label_first.reshape(-1),        # (N,)
#         'Pred_Label_Second': gnn_pred_label_second.reshape(-1),
#         'Pred_Label_Refined': gnn_refined_label.reshape(-1),
#         'Loss_First': [loss_value] * N,
#         'Acc_First': [float(acc_gnn_first)] * N,
#     }
#
#     Dp = penultimate_features.shape[1]
#     for i in range(Dp):
#         # data[f'penultimate_GNN_Feature{i + 1}'] = penultimate_features[:, i]
#         data[f'pen_GNN_Feature_second{i + 1}'] = penultimate_features_second[:, i]
#
#     for i in range(adjacency_eigvecs.shape[1]):
#         data[f'Adj_EigVecs_Top{i + 1}'] = adjacency_eigvecs[:, i]
#
#     df = pd.DataFrame(data)
#     root_folder = "penultimate_GNN_Feature"
#     subfolder_name = f"penultimate_GNN_Feature_nclasses_{n_classes}"
#     output_filename = f"first_gnn_fromA_iter{iter}_j={args.J}_nlyr={args.num_layers}.xlsx"
#     output_path = os.path.join(root_folder, subfolder_name, output_filename)
#     os.makedirs(os.path.dirname(output_path), exist_ok=True)
#
#     df.to_excel(output_path, sheet_name=f'Iteration_{iter}', index=False)
#
#     # === 8) 汇总指标 ===
#     acc_sc_n, ari_sc_n, nmi_sc_n = res_spectral['normalized']['sc']
#     acc_rf_n, ari_rf_n, nmi_rf_n = res_spectral['normalized']['refined']
#     acc_sc_u, ari_sc_u, nmi_sc_u = res_spectral['unnormalized']['sc']
#     acc_rf_u, ari_rf_u, nmi_rf_u = res_spectral['unnormalized']['refined']
#     acc_sc_a, ari_sc_a, nmi_sc_a = res_spectral['adjacency']['sc']
#     acc_rf_a, ari_rf_a, nmi_rf_a = res_spectral['adjacency']['refined']
#     acc_sc_svt, ari_sc_svt, nmi_sc_svt = res_spectral['svt']['sc']
#     acc_rf_svt, ari_rf_svt, nmi_rf_svt = res_spectral['svt']['refined']
#
#     metrics = {
#         "gnn_first": {"acc": float(acc_gnn_first), "ari": float(ari_gnn_first), "nmi": float(nmi_gnn_first)},
#         "gnn_second": {"acc": float(acc_gnn_second), "ari": float(ari_gnn_second), "nmi": float(nmi_gnn_second)},
#         "gnn_refined": {"acc": float(acc_gnn_refined), "ari": float(ari_gnn_refined), "nmi": float(nmi_gnn_refined)},
#         "spectral_normalized": {"acc": float(acc_sc_n), "ari": float(ari_sc_n), "nmi": float(nmi_sc_n)},
#         "spectral_normalized_refined": {"acc": float(acc_rf_n), "ari": float(ari_rf_n), "nmi": float(nmi_rf_n)},
#         "spectral_unnormalized": {"acc": float(acc_sc_u), "ari": float(ari_sc_u), "nmi": float(nmi_sc_u)},
#         "spectral_unnormalized_refined": {"acc": float(acc_rf_u), "ari": float(ari_rf_u), "nmi": float(nmi_rf_u)},
#         "spectral_adjacency": {"acc": float(acc_sc_a), "ari": float(ari_sc_a), "nmi": float(nmi_sc_a)},
#         "spectral_adjacency_refined": {"acc": float(acc_rf_a), "ari": float(ari_rf_a), "nmi": float(nmi_rf_a)},
#         "spectral_svt": {"acc": float(acc_sc_svt), "ari": float(ari_sc_svt), "nmi": float(nmi_sc_svt)},
#         "spectral_svt_refined": {"acc": float(acc_rf_svt), "ari": float(ari_rf_svt), "nmi": float(nmi_rf_svt)},
#     }
#
#     return loss_value, metrics


# Method names aggregated per iteration. Each is looked up in the metrics dict
# returned by test_single_first_period; a missing name is recorded as NaN.
METHODS = [
    # first-period GNN, raw and locally refined
    "gnn",
    "gnn_refined",
    # spectral clustering on three matrix variants, raw and refined
    "spectral_normalized",
    "spectral_normalized_refined",
    "spectral_unnormalized",
    "spectral_unnormalized_refined",
    "spectral_adjacency",
    "spectral_adjacency_refined",
    # logistic regression on spectral features, raw and refined
    "logistic",
    "logistic_refined",
]


def test_first_period(
    gnn_first_period,
    gnn_second_period,
    lr_model,
    n_classes,
    gen,
    args,
    iters=None,
    mode='balanced',
    class_sizes=None,
):
    """
    Run ``iters`` evaluation episodes and return three summary rows (ACC / ARI / NMI).

    Each row contains the metadata columns (n_classes, class_sizes, p_SBM, q_SBM,
    J, N_train, N_test, SNR) followed by one column per name in ``METHODS``,
    holding the NaN-ignoring mean of that metric over all episodes. Nothing is
    written to disk; hand the rows to ``append_rows_to_excel``.

    Each episode calls ``test_single_first_period(...)``, which is expected to
    return ``(loss, metrics)`` where ``metrics`` maps a name from ``METHODS`` to a
    dict with keys ``"acc"``, ``"ari"`` and ``"nmi"``. A missing method or key,
    ``None``, or a non-numeric / non-finite value is recorded as NaN.

    Args:
        gnn_first_period: first-stage model (switched to ``train()`` mode here).
        gnn_second_period: second-stage model, forwarded to each episode.
        lr_model: logistic-regression baseline, forwarded to each episode.
        n_classes: number of communities.
        gen: graph generator; its ``p_SBM`` / ``q_SBM`` are read for metadata.
        args: run configuration (``num_examples_test``, ``J``, ``N_train``, ...).
        iters: number of episodes; defaults to ``args.num_examples_test``.
        mode: forwarded to ``test_single_first_period``.
        class_sizes: forwarded to ``test_single_first_period`` and recorded.

    Returns:
        (row_acc, row_ari, row_nmi): three dicts with identical metadata columns.
    """
    if iters is None:
        iters = args.num_examples_test

    # NOTE(review): evaluation runs with the model in train() mode, not eval().
    # Kept as-is because it may be intentional (e.g. normalization layers that
    # should use per-graph statistics) -- confirm.
    gnn_first_period.train()

    def _to_finite_float(value):
        """Convert a metric value to float; non-numeric or non-finite values become NaN."""
        try:
            value = float(value)
        except (TypeError, ValueError):
            return np.nan
        return value if np.isfinite(value) else np.nan

    def _nan_mean(values):
        """NaN-ignoring mean; NaN (without a RuntimeWarning) if no finite value exists."""
        arr = np.asarray(values, dtype=float)
        if arr.size == 0 or np.isnan(arr).all():
            return np.nan
        return float(np.nanmean(arr))

    def _as_int(value, default=-1):
        """int(value), or ``default`` when the attribute is present but None."""
        return int(value) if value is not None else default

    # Per-method, per-metric sequence of episode results.
    buckets = {m: {"acc": [], "ari": [], "nmi": []} for m in METHODS}

    for it in range(iters):
        _, metrics = test_single_first_period(
            gnn_first_period=gnn_first_period,
            gnn_second_period=gnn_second_period,
            lr_model=lr_model,
            gen=gen,
            n_classes=n_classes,
            args=args,
            iter=it,
            mode=mode,
            class_sizes=class_sizes,
        )

        for meth in METHODS:
            md = metrics.get(meth) or {}
            for key in ("acc", "ari", "nmi"):
                buckets[meth][key].append(_to_finite_float(md.get(key, np.nan)))

        torch.cuda.empty_cache()

    # ---- metadata ----
    n = getattr(args, "N_train", getattr(gen, "N_train", None)) or getattr(args, "N_test", getattr(gen, "N_test", -1))
    # log(n)/n is 0 for n == 1, which would make a and b infinite; require n > 1.
    logn_div_n = np.log(n) / n if n and n > 1 else np.nan
    a = gen.p_SBM / logn_div_n if np.isfinite(logn_div_n) else np.nan
    b = gen.q_SBM / logn_div_n if np.isfinite(logn_div_n) else np.nan
    k = n_classes
    snr = np.nan
    if np.isfinite(a) and np.isfinite(b):
        denom = k * (a + (k - 1) * b)
        # Guard against p = q = 0 (0 / 0).
        if denom > 0:
            snr = (a - b) ** 2 / denom

    meta = {
        "n_classes": int(n_classes),
        "class_sizes": str(class_sizes) if class_sizes is not None else "",
        "p_SBM": float(gen.p_SBM),
        "q_SBM": float(gen.q_SBM),
        "J": int(args.J),
        "N_train": _as_int(getattr(args, "N_train", getattr(gen, "N_train", -1))),
        "N_test": _as_int(getattr(args, "N_test", getattr(gen, "N_test", -1))),
        "SNR": float(snr) if np.isfinite(snr) else np.nan,
    }

    # ---- one row per metric: metadata first, then METHODS in order ----
    rows = {}
    for metric in ("acc", "ari", "nmi"):
        row = dict(meta)
        row.update({meth: _nan_mean(buckets[meth][metric]) for meth in METHODS})
        rows[metric] = row

    return rows["acc"], rows["ari"], rows["nmi"]

def append_rows_to_excel(row_acc: dict, row_ari: dict, row_nmi: dict, filename="summary.xlsx", extra_info: dict=None):
    """
    Append one row to each of the ACC / ARI / NMI sheets of an Excel workbook.

    - If ``filename`` does not exist yet, it is created (with a header row).
    - Otherwise the target sheet is read, the old and new columns are aligned
      (existing columns first, then new ones), and the sheet is rewritten.
      A sheet that does not exist yet is simply created.
    - ``extra_info`` (e.g. ``{"class_sizes": [100, 900], "snr": 0.5, "total_ab": 10}``)
      is merged into every row, overriding keys of the same name. A list/tuple/ndarray
      ``class_sizes`` is stored as a dash-joined string such as ``'100-900'`` so it is
      easy to filter on; any other ``class_sizes`` value is stored via ``str()``.
    """
    # Normalize the extra columns once; the same values go into all three sheets.
    extras = {}
    if extra_info:
        extras = dict(extra_info)
        if "class_sizes" in extras:
            sizes = extras["class_sizes"]
            if isinstance(sizes, (list, tuple, np.ndarray)):
                extras["class_sizes"] = "-".join(str(s) for s in sizes)
            else:
                extras["class_sizes"] = str(sizes)

    target = Path(filename)
    for sheet, row in (("ACC", row_acc), ("ARI", row_ari), ("NMI", row_nmi)):
        fresh = pd.DataFrame([{**row, **extras}])

        if not target.exists():
            # First write: create the workbook.
            with pd.ExcelWriter(filename, engine="openpyxl", mode="w") as writer:
                fresh.to_excel(writer, sheet_name=sheet, index=False)
            continue

        try:
            previous = pd.read_excel(target, sheet_name=sheet)
            columns = list(dict.fromkeys([*previous.columns, *fresh.columns]))
            combined = pd.concat(
                [previous.reindex(columns=columns), fresh.reindex(columns=columns)],
                ignore_index=True,
            )
        except ValueError:
            # The sheet is not in the workbook yet; it will be created below.
            combined = fresh

        # Rewrite just this sheet; the other sheets of the workbook are kept.
        with pd.ExcelWriter(filename, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
            combined.to_excel(writer, sheet_name=sheet, index=False)



def test_first_period_wrapper(args_tuple):
    """
    Pool worker: evaluate a single (class_sizes, SNR, total_ab) configuration.

    ``args_tuple`` is ``(gnn_first_period, gnn_second_period, lr_model,
    class_sizes, snr, gen, logN_div_N, total_ab)``. ``find_a_given_snr`` turns the
    SNR and ``total_ab`` into (a, b). The SBM probabilities are then
    ``p = round(a * logN_div_N, 4)`` and ``q = round(b * logN_div_N, 4)``. They are
    set on a copy of ``gen`` rather than on the ``gen`` passed in.

    Reads the module-level ``args`` for ``n_classes``, ``num_examples_test`` and
    ``mode_isbalanced``.

    Returns:
        (row_acc, row_ari, row_nmi) as produced by ``test_first_period``.
    """
    import copy

    (gnn_first_period, gnn_second_period, lr_model,
     class_sizes, snr, gen, logN_div_N, total_ab) = args_tuple

    a, b = find_a_given_snr(snr, args.n_classes, total_ab)
    p_SBM = round(a * logN_div_N, 4)
    q_SBM = round(b * logN_div_N, 4)

    # Work on a private generator copy. Prefer Generator.copy() when it is
    # implemented; otherwise fall back to a deep copy instead of failing with
    # AttributeError.
    copier = getattr(gen, "copy", None)
    gen_local = copier() if callable(copier) else copy.deepcopy(gen)
    gen_local.p_SBM = p_SBM
    gen_local.q_SBM = q_SBM

    print(f"\n[测试阶段] class_sizes: {class_sizes}, SNR: {snr:.2f}")
    print(f"使用的 SBM 参数: p={p_SBM}, q={q_SBM}")

    row_acc, row_ari, row_nmi = test_first_period(
        gnn_first_period=gnn_first_period,
        gnn_second_period=gnn_second_period,
        lr_model=lr_model,
        n_classes=args.n_classes,
        gen=gen_local,
        args=args,
        iters=args.num_examples_test,
        mode=args.mode_isbalanced,   # 'balanced' or 'imbalanced'
        class_sizes=class_sizes,
    )

    return row_acc, row_ari, row_nmi


def count_parameters(model):
    """Return the total number of scalar elements across the trainable parameters
    (those with ``requires_grad=True``) of ``model``."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total += param.numel()
    return total



def _build_first_period_model(args):
    """Instantiate an untrained first-stage GNN for ``args``.

    Raises:
        ValueError: if ``args.generative_model`` has no first-stage model, instead
            of leaving the model variable undefined (NameError later on).
    """
    if args.generative_model != 'SBM_multiclass':
        raise ValueError(f"Unsupported generative_model: {args.generative_model!r}")
    return GNN_multiclass(args.num_features, args.num_layers, args.J + 3,
                          n_classes=args.n_classes)


def _build_second_period_model(args):
    """Instantiate an untrained second-stage GNN for ``args``.

    Raises:
        ValueError: if ``args.generative_model`` has no second-stage model.
    """
    if args.generative_model != 'SBM_multiclass':
        raise ValueError(f"Unsupported generative_model: {args.generative_model!r}")
    return GNN_multiclass_second_period(
        args.num_features_second, args.num_layers_second, args.J_second + 2, n_classes=args.n_classes
    )


if __name__ == '__main__':
    try:
        setup_logger("run_gnn")
        gen = Generator(N_train=args.N_train, N_test=args.N_test, N_val=args.N_val, n_classes=args.n_classes)

        gen.edge_density = args.edge_density
        gen.p_SBM = args.p_SBM
        gen.q_SBM = args.q_SBM

        gen.random_noise = args.random_noise
        gen.noise = args.noise
        gen.noise_model = args.noise_model
        gen.generative_model = args.generative_model
        # gen.n_classes = args.n_classes

        gen.num_examples_train = args.num_examples_train
        gen.num_examples_test = args.num_examples_test
        gen.num_examples_val = args.num_examples_val

        # 1. Root folder for all models.
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)

        # 2. One sub-folder per n_classes.
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        # 3. Checkpoint paths for both stages.
        filename_first = (
            f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        )
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = (
            f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        )

        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            ################################################################################################################
            # Training: prepare the data loaders.
            gen.prepare_data()

            # 1. Number of DataLoader worker processes (0 means load in the main process).
            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            # 2. Dataset instances.
            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            # 3. Matching DataLoaders.
            train_loader = DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,  # validation order does not need shuffling
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            test_loader = DataLoader(
                test_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            print(
                f"[阶段一] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers}, J={args.J}, num_features={args.num_features}")

            # === Stage 1 ===
            if os.path.exists(path_first):
                print(f"[阶段一] 检测到已有‘最优权重’，直接载入: {path_first}")
                # Make sure gnn_first_period exists and holds the saved best weights.
                gnn_first_period = _build_first_period_model(args).to(device)
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)
            else:
                # Build and train the first stage.
                gnn_first_period = _build_first_period_model(args)
                if torch.cuda.is_available():
                    gnn_first_period = gnn_first_period.to(device)

                loss_list, acc_list = train_first_period_with_early_stopping(
                    gnn_first_period, train_loader, val_loader, args.n_classes, args,
                    epochs=100, patience=4, save_path=path_first, filename=filename_first
                )

                # Reload the best checkpoint from save_path right after training.
                print(f"[阶段一] 从最优权重回载到内存: {path_first}")
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)

            # Stage 1 is used as a fixed feature extractor for stage 2, so freeze it.
            maybe_freeze(gnn_first_period, freeze=True)

            # === Stage 2 ===
            print(
                f"[阶段二] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers_second}, J={args.J_second}, num_features={args.num_features_second}")

            if os.path.exists(path_second):
                print(f"[阶段二] 检测到已有模型文件，直接载入: {path_second}")
                # Load through the same helper as stage 1 so every checkpoint format
                # load_best_model_into supports works here too.
                gnn_second_period = _build_second_period_model(args).to(device)
                gnn_second_period = load_best_model_into(gnn_second_period, path_second, device)
            else:
                gnn_second_period = _build_second_period_model(args)
                if torch.cuda.is_available():
                    gnn_second_period = gnn_second_period.to(device)

                loss_list, acc_list = train_second_period_with_early_stopping(
                    gnn_first_period, gnn_second_period, train_loader, val_loader,
                    args.n_classes, args,
                    epochs=100, patience=4, save_path=path_second, filename=filename_second
                )

                print(f"[阶段二] 保存模型到 {path_second}")

        print("[测试阶段] 开始...")
        print("训练逻辑回归模型...")
        lr_model = train_logistic_regression(gen, args.n_classes)

        # Build fresh CPU models and load the saved checkpoints into them. The models
        # are pickled to the Pool workers below. A checkpoint may hold a state_dict
        # (the format load_best_model_into handles), in which case a bare torch.load
        # would not yield a usable model. This also works when args.mode != "train",
        # where no model has been built yet.
        cpu = torch.device('cpu')
        gnn_first_period = load_best_model_into(_build_first_period_model(args), path_first, cpu)
        gnn_second_period = load_best_model_into(_build_second_period_model(args), path_second, cpu)

        print("[测试阶段] 开始...")

        class_sizes_list = [
            [250, 250, 250, 250],  # fully balanced
            [400, 200, 200, 200],  # one larger class
            [550, 150, 150, 150],  # more imbalanced
            [700, 100, 100, 100],  # highly imbalanced
            [850, 50, 50, 50],     # extreme
        ]

        snr_list = [0.1, 0.25, 0.5, 0.75, 1, 1.5, 2]

        total_ab_list = [5, 10, 15]
        # total_ab_list = [5]

        N = 1000

        logN_div_N = np.log(N) / N  # ~0.0069

        # One task per (total_ab, class_sizes, snr) combination.
        task_args = [
            (gnn_first_period, gnn_second_period, lr_model, class_sizes, snr, gen, logN_div_N, total_ab)
            for total_ab in total_ab_list
            for class_sizes in class_sizes_list
            for snr in snr_list
        ]

        print("[测试阶段] 并行执行中...")

        # os.cpu_count() may be None or 1; Pool requires at least one process.
        n_procs = max(1, (os.cpu_count() or 2) // 2)
        with Pool(processes=n_procs) as pool:
            results = pool.map(test_first_period_wrapper, task_args)

        # One Excel file per total_ab value.
        for (row_acc, row_ari, row_nmi), (_, _, _, class_sizes, snr, _, _, total_ab) in zip(results, task_args):
            filename = f"summary_total{total_ab}.xlsx"
            append_rows_to_excel(
                row_acc, row_ari, row_nmi,
                filename=filename,
                extra_info={"class_sizes": class_sizes, "snr": snr, "total_ab": total_ab}
            )

    except Exception:

        import traceback

        traceback.print_exc()
        # Report the failure through the exit status instead of exiting with 0.
        sys.exit(1)