import argparse
import multiprocessing
import os
import sys
import time
from multiprocessing import Pool

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from GNN_local_refinement import run_refinement_chain
from data_generator_lsm import Generator, simple_collate_fn
from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from losses import from_scores_to_labels_multiclass_batch
from models import GNN_multiclass, GNN_multiclass_second_period
from train_first_period import train_first_period_with_early_stopping
from train_second_period import train_second_period_with_early_stopping


def setup_logger(prefix="main_gnn"):
    """Redirect this process's stdout and stderr to a new log file.

    The file is created in the current working directory as
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``. It is opened line-buffered and
    intentionally left open for the rest of the process, because it becomes
    both ``sys.stdout`` and ``sys.stderr``.

    Args:
        prefix: Leading part of the log file name.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    log_filename = f"{prefix}_{timestamp}_{pid}.log"
    # Explicit UTF-8: the script prints Chinese status messages, which would raise
    # UnicodeEncodeError if the locale's default encoding is not UTF-8 (e.g. cp1252).
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")

def load_best_model_into(model, ckpt_path, device):
    """Load weights from ``ckpt_path`` into ``model`` and return the usable model.

    Three checkpoint layouts are accepted:

    * a dict containing ``'model_state'`` -> that state_dict is loaded;
    * an object exposing ``state_dict()`` (e.g. a pickled ``nn.Module``) -> its
      state is loaded; if that load raises, the checkpoint object itself is moved
      to ``device`` and returned instead of ``model``;
    * anything else -> treated as a bare state_dict.

    All loads use ``strict=False``, so missing or unexpected keys are tolerated.

    Security: ``weights_only=False`` unpickles arbitrary objects; only pass
    trusted checkpoint files.
    """
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
        model.load_state_dict(checkpoint['model_state'], strict=False)
        return model

    if not hasattr(checkpoint, 'state_dict'):
        # Bare state_dict saved directly.
        model.load_state_dict(checkpoint, strict=False)
        return model

    try:
        model.load_state_dict(checkpoint.state_dict(), strict=False)
    except Exception:
        # Architecture mismatch: fall back to the pickled module itself.
        return checkpoint.to(device)
    return model

def maybe_freeze(model, freeze=True):
    """Put ``model`` in eval mode and, if ``freeze`` is true, stop gradients to all its parameters."""
    trainable = model.parameters() if freeze else ()
    for param in trainable:
        param.requires_grad = False
    model.eval()

parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#          Configured up front so the rest of the script can read them        #
###############################################################################
# Convention used by most options below: nargs='?', const=1 means that passing
# the flag with no value stores the int 1 (argparse does not run a non-string
# const through `type`).

parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6000))
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(30))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1000))
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')

# LSM generator parameters.
# NOTE(review): no args.lsm_* value is read in this file; the test sweep in
# __main__ sets C / norm_mu / radii explicitly. Presumably these are consumed
# elsewhere (generator or training code); confirm.
parser.add_argument('--lsm_tau', nargs='?', const=1, type=float, default=0.5)
parser.add_argument('--lsm_gamma', nargs='?', const=1, type=float, default=1.0)
parser.add_argument('--lsm_C', nargs='?', const=1, type=float, default=20.0)
parser.add_argument('--lsm_norm_mu', nargs='?', const=1, type=float, default=0.5)
parser.add_argument('--lsm_alpha_std', nargs='?', const=1, type=float, default=0.0)
# parser.add_argument('--lsm_latent_dim', nargs='?', const=1, type=int, default=4)
# parser.add_argument('--lsm_radii', nargs='+', type=float, default=[0.8, 0.9,1, 1.1])

parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
# Only 'SBM_multiclass' has a model constructor in the __main__ block.
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 1)
# 'train' trains (or reloads) both stages before testing; any other value goes straight to testing.
parser.add_argument('--mode', nargs='?', const=1, type=str, default='test')
# Directory containing this script (os.path.join with one argument is a no-op).
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): set_defaults(eval_vs_train=True) makes --freeze_bn a no-op:
# eval_vs_train is True whether or not the flag is passed. Confirm intent.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

###############################################################################
#                                 GNN first period                            #
###############################################################################
parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
# The model input width is J + 3 (see the GNN_multiclass constructor calls in __main__).
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default= 8)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default= 3)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=2)
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

# Parsed at import time: importing this module (including re-imports in
# spawned worker processes) parses sys.argv.
args = parser.parse_args()

# Legacy tensor-type aliases: CUDA variants when a GPU is visible, CPU otherwise.
# dtype_l is used to cast label tensors (see test_single_first_period).
dtype, dtype_l = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor)
    if torch.cuda.is_available()
    else (torch.FloatTensor, torch.LongTensor)
)

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Column layouts for the per-graph log table: template1 is the header and
# template2 the value row printed by test_single_first_period.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'

# Alternative, wider header/row layout.
template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'

class SBMDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``.npz`` graph files.

    Each file must contain the CSR components ``adj_data``, ``adj_indices``,
    ``adj_indptr``, ``adj_shape`` and a ``labels`` array.
    """

    def __init__(self, npz_file_list):
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``{'adj': csr_matrix, 'labels': ndarray, 'num_nodes': int}`` for file ``idx``."""
        # np.load on an .npz returns an open NpzFile; use it as a context manager
        # so the file handle is closed right away. Otherwise handles leak until
        # garbage collection, once per sample and per DataLoader worker.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}

# Sparse helpers
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors

import pandas as pd
import numpy as np

def get_available_device():
    """Return the first CUDA device that can allocate a tensor, else the CPU.

    Each candidate GPU is made current with ``torch.cuda.set_device`` and probed
    with a one-element allocation. A GPU whose probe raises ``RuntimeError`` is
    skipped. Note that the last device tried remains the current CUDA device.
    """
    gpu_count = torch.cuda.device_count()
    index = 0
    while index < gpu_count:
        try:
            torch.cuda.set_device(index)
            torch.zeros(1).cuda()
        except RuntimeError:
            index += 1
            continue
        return torch.device(f"cuda:{index}")
    return torch.device("cpu")


# Chosen once at import time and used as the module-wide default device.
device = get_available_device()

def test_single_first_period(gnn_first_period, gnn_second_period, gen, n_classes, args, iter, mode='balanced', class_sizes=None, C=None, norm_mu=None, radii=None):
    """Evaluate the two-stage GNN pipeline on one freshly generated LSM graph.

    Pipeline: sample a graph with ``gen.gen_one_lsm`` -> first-period GNN ->
    ``run_refinement_chain`` with the second-period GNN (``total_iters=10``) ->
    ``local_refinement_by_neighbors`` applied to the first-period labels.

    Args:
        gnn_first_period: First-stage model, called as ``model(WW, x)``.
        gnn_second_period: Second-stage model, passed to ``run_refinement_chain``.
        gen: Generator exposing ``gen_one_lsm``.
        n_classes: Number of communities.
        args: Parsed CLI namespace. This function reads ``J``, ``N_test``,
            ``edge_density`` and ``noise``, and forwards ``args`` to the refinement chain.
        iter: Index of this test graph; forwarded to the generator and printed.
        mode: Must be ``'imbalanced'``, the only generation path implemented.
        class_sizes, C, norm_mu, radii: Forwarded to ``gen.gen_one_lsm``.

    Returns:
        ``(loss, metrics, lsm_params)``:

        * ``loss`` is the first-period loss as a float.
        * ``metrics`` maps ``"gnn"``, ``"gnn_second"``, ``"gnn_final"`` and
          ``"gnn_refined"`` to ``{"acc", "ari", "nmi"}`` floats. The
          ``"gnn_second"`` values are NaN when the chain reports no ``"first_iter"``.
        * ``lsm_params`` holds ``p_lsm``, ``q_lsm`` and ``SNR`` derived from the
          generator's ``B`` matrix.

    Raises:
        ValueError: If ``mode`` is not ``'imbalanced'``.
    """
    from torch import Tensor

    def _sync():
        # Wait for queued CUDA work so timings reflect finished kernels (no-op on CPU).
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    if mode != 'imbalanced':
        # The balanced/SBM path was removed; previously this case crashed later
        # with an UnboundLocalError on `W`.
        raise ValueError(
            f"test_single_first_period only supports mode='imbalanced', got {mode!r}"
        )

    t_all0 = time.perf_counter()
    # NOTE(review): both models are put in train() mode for evaluation; if they
    # contain BatchNorm/Dropout this changes their outputs. Confirm this is intended.
    gnn_first_period.train()
    gnn_second_period.train()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ===== 1) Generate the graph =====
    W, true_labels, B = gen.gen_one_lsm(
        class_sizes,
        C=C,
        norm_mu=norm_mu,
        radii=radii,
        is_training=True,
        iter=iter,
    )
    # p: mean of B's diagonal; q: mean of its off-diagonal entries.
    p_lsm = np.mean(np.diag(B))
    q_lsm = (np.sum(B) - np.sum(np.diag(B))) / (B.shape[0] * (B.shape[1] - 1))
    N = args.N_test
    SNR = (p_lsm - q_lsm) ** 2 * N / (2 * (p_lsm + q_lsm)) / np.log(N)
    true_labels = true_labels.type(dtype_l)
    lsm_params = {
        "p_lsm": float(p_lsm),
        "q_lsm": float(q_lsm),
        "SNR": float(SNR),
    }

    # ===== 2) Build the GNN inputs =====
    W = W.to(device)  # original note: shape (B, N, N) — batch-first adjacency
    W_np = W.detach().cpu().numpy() if torch.is_tensor(W) else np.asarray(W)

    # Original note: WW is (B, N, N, J+3) and x is (B, N, d).
    WW, x = get_gnn_inputs(W.detach().cpu().numpy(), args.J)
    WW = WW.clone().detach().to(torch.float32).to(device)
    x = x.clone().detach().to(torch.float32).to(device)

    # ===== 3) First-period GNN forward pass =====
    with torch.no_grad():
        _sync()
        pred_single_first = gnn_first_period(WW, x)  # per-node class scores
        _sync()
        start_x_label = from_scores_to_labels_multiclass_batch(
            pred_single_first.detach().cpu().numpy()
        )

    # ===== 4) Refinement chain with the second-period GNN =====
    with torch.no_grad():
        ref = run_refinement_chain(
            gnn_second_period=gnn_second_period,
            W_np=W_np,
            init_labels=start_x_label,
            true_labels=true_labels,
            args=args,
            device=device,
            total_iters=10,
            verbose=False,
        )
        _sync()

    first_iter = ref.get("first_iter")
    if first_iter is not None:
        acc_gnn_second = first_iter["acc"]
        ari_gnn_second = first_iter["ari"]
        nmi_gnn_second = first_iter["nmi"]
    else:
        # Previously left unbound (NameError when building `metrics`); NaN is
        # skipped by the nanmean aggregation in test_first_period.
        acc_gnn_second = ari_gnn_second = nmi_gnn_second = float("nan")

    acc_gnn_final = ref["final"]["acc"]
    ari_gnn_final = ref["final"]["ari"]
    nmi_gnn_final = ref["final"]["nmi"]

    # ===== 5) Loss and metrics of the first-period GNN =====
    W_np = W.squeeze(0).detach().cpu().numpy() if isinstance(W, Tensor) else np.asarray(W).squeeze(0)

    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    _sync()
    acc_gnn_first, gnn_pred_label_first, ari_gnn_first, nmi_gnn_first = gnn_compute_acc_ari_nmi_multiclass(
        pred_single_first, true_labels, n_classes
    )

    # ===== 6) Local refinement by neighbour majority vote =====
    gnn_refined = local_refinement_by_neighbors(W_np, gnn_pred_label_first, n_classes)
    acc_gnn_refined, _, ari_gnn_refined, nmi_gnn_refined = compute_acc_ari_nmi(
        gnn_refined, true_labels, n_classes
    )

    # ===== 7) Summary (total time only) =====
    total_elapsed = time.perf_counter() - t_all0
    loss_value = float(loss_test_first.detach().cpu())

    info = ['iter', 'avg loss', 'avg acc', 'edge_density', 'noise', 'model', 'elapsed']
    out = [iter, loss_value, acc_gnn_first, args.edge_density, args.noise, 'GNN', total_elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    del WW, x

    metrics = {
        "gnn": {"acc": float(acc_gnn_first), "ari": float(ari_gnn_first), "nmi": float(nmi_gnn_first)},
        "gnn_second": {"acc": float(acc_gnn_second), "ari": float(ari_gnn_second), "nmi": float(nmi_gnn_second)},
        "gnn_final": {"acc": float(acc_gnn_final), "ari": float(ari_gnn_final), "nmi": float(nmi_gnn_final)},
        "gnn_refined": {"acc": float(acc_gnn_refined), "ari": float(ari_gnn_refined), "nmi": float(nmi_gnn_refined)},
    }

    return loss_value, metrics, lsm_params


# Method keys aggregated by test_first_period. Each key is looked up in the
# metrics dict returned by test_single_first_period.
METHODS = ["gnn", "gnn_second", "gnn_refined", "gnn_final"]


def test_first_period(
    gnn_first_period,
    gnn_second_period,
    n_classes,
    gen,
    args,
    iters=None,
    mode='balanced',
    class_sizes=None,
    C=None,  # LSM parameters forwarded to test_single_first_period
    norm_mu=None,
    radii=None
):
    """Run ``iters`` single-graph evaluations and average their metrics.

    Args:
        gnn_first_period, gnn_second_period: Models forwarded to
            ``test_single_first_period``.
        n_classes: Number of communities.
        gen: Graph generator.
        args: Parsed CLI namespace. Reads ``num_examples_test`` (default for
            ``iters``), ``J``, ``N_train`` and ``N_test``.
        iters: Number of graphs to evaluate; defaults to ``args.num_examples_test``.
        mode, class_sizes, C, norm_mu, radii: Forwarded per graph.

    Returns:
        ``(row_acc, row_ari, row_nmi, avg_lsm_params)``. Each row holds shared
        metadata plus the NaN-ignoring mean of one metric for every name in
        ``METHODS``. ``avg_lsm_params`` holds the mean ``p_lsm``/``q_lsm``/``SNR``.
    """
    if iters is None:
        iters = args.num_examples_test

    gnn_first_period.train()
    buckets = {m: {"acc": [], "ari": [], "nmi": []} for m in METHODS}
    lsm_params_list = []
    for it in range(iters):
        _loss_val, metrics, lsm_params = test_single_first_period(
            gnn_first_period=gnn_first_period,
            gnn_second_period=gnn_second_period,
            gen=gen,
            n_classes=n_classes,
            args=args,
            iter=it,
            mode=mode,
            class_sizes=class_sizes,
            C=C,
            norm_mu=norm_mu,
            radii=radii
        )

        lsm_params_list.append(lsm_params)
        for meth in METHODS:
            md = metrics.get(meth, {})
            for k in ("acc", "ari", "nmi"):
                v = md.get(k, np.nan)
                # Non-finite values are recorded as NaN and ignored by nanmean below.
                buckets[meth][k].append(float(v) if np.isfinite(v) else np.nan)

        # Release cached GPU memory between graphs (skipped entirely on CPU).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    avg_lsm_params = {
        "p_lsm": np.nanmean([p["p_lsm"] for p in lsm_params_list]),
        "q_lsm": np.nanmean([p["q_lsm"] for p in lsm_params_list]),
        "SNR": np.nanmean([p["SNR"] for p in lsm_params_list]),
    }

    def mean(vals):
        return float(np.nanmean(np.asarray(vals, dtype=float)))

    meta = {
        "n_classes": int(n_classes),
        "class_sizes": str(class_sizes) if class_sizes is not None else "",
        "p_LSM": float(avg_lsm_params["p_lsm"]),
        "q_LSM": float(avg_lsm_params["q_lsm"]),
        "J": int(args.J),
        "N_train": int(getattr(args, "N_train", getattr(gen, "N_train", -1))),
        "N_test": int(getattr(args, "N_test", getattr(gen, "N_test", -1))),
        # float(nan) stays NaN, so no separate finiteness guard is needed (the
        # old guard tested p_lsm rather than SNR).
        "SNR": float(avg_lsm_params["SNR"]),
    }

    row_acc = {**meta}
    row_ari = {**meta}
    row_nmi = {**meta}
    for meth in METHODS:
        row_acc[meth] = mean(buckets[meth]["acc"])
        row_ari[meth] = mean(buckets[meth]["ari"])
        row_nmi[meth] = mean(buckets[meth]["nmi"])

    return row_acc, row_ari, row_nmi, avg_lsm_params

from pathlib import Path

def append_rows_to_excel(row_acc: dict, row_ari: dict, row_nmi: dict, filename="summary.xlsx", extra_info: dict=None):
    """Append one row to each of the ACC, ARI and NMI sheets of ``filename``.

    The ``extra_info`` entries are merged into every row and override keys of
    the same name. A list/tuple/ndarray ``class_sizes`` is rendered as "a-b-c";
    any other value is rendered with ``str``. Existing sheets are read back, the
    columns are unioned (old order first), and the sheet is rewritten. A missing
    sheet (``ValueError`` while merging) starts from the new row only. A missing
    workbook is created.
    """
    def _flatten_extra(info):
        # Copy `info` so the caller's dict is never mutated.
        if not info:
            return {}
        flat = dict(info)
        if "class_sizes" in flat:
            sizes = flat["class_sizes"]
            if isinstance(sizes, (list, tuple, np.ndarray)):
                flat["class_sizes"] = "-".join(map(str, sizes))
            else:
                flat["class_sizes"] = str(sizes)
        return flat

    extras = _flatten_extra(extra_info)

    def _write_sheet(sheet_name, row):
        df_new = pd.DataFrame([{**row, **extras}])

        if not Path(filename).exists():
            with pd.ExcelWriter(filename, engine="openpyxl", mode="w") as writer:
                df_new.to_excel(writer, sheet_name=sheet_name, index=False)
            return

        try:
            df_old = pd.read_excel(Path(filename), sheet_name=sheet_name)
            columns = list(dict.fromkeys([*df_old.columns, *df_new.columns]))
            df_out = pd.concat(
                [df_old.reindex(columns=columns), df_new.reindex(columns=columns)],
                ignore_index=True,
            )
        except ValueError:
            df_out = df_new
        with pd.ExcelWriter(filename, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
            df_out.to_excel(writer, sheet_name=sheet_name, index=False)

    for sheet_name, row in (("ACC", row_acc), ("ARI", row_ari), ("NMI", row_nmi)):
        _write_sheet(sheet_name, row)

def test_first_period_wrapper(args_tuple):
    """Pool worker that evaluates one LSM configuration.

    Args:
        args_tuple: The 8-tuple ``(gnn_first_period, gnn_second_period,
            class_sizes, norm_mu, radii, gen, logN_div_N, C)`` built in
            ``__main__``. ``logN_div_N`` is accepted for tuple-layout
            compatibility but not used here.

    Returns:
        ``(row_acc, row_ari, row_nmi, avg_lsm_params)`` from ``test_first_period``.
    """
    (first_model, second_model, class_sizes, norm_mu, radii,
     base_gen, _unused_logN_div_N, C) = args_tuple

    # Set the LSM parameters on a copy (via gen.copy()), not on the generator passed in.
    gen_local = base_gen.copy()
    gen_local.lsm_C = C
    gen_local.lsm_norm_mu = norm_mu
    gen_local.lsm_radii = radii

    print(f"\n[测试阶段] class_sizes: {class_sizes}, C: {C}, norm_mu: {norm_mu:.3f}")
    print(f"LSM 参数: radii={radii}")

    results = test_first_period(
        gnn_first_period=first_model,
        gnn_second_period=second_model,
        n_classes=args.n_classes,
        gen=gen_local,
        args=args,
        iters=args.num_examples_test,
        mode=args.mode_isbalanced,
        class_sizes=class_sizes,
        C=C,
        norm_mu=norm_mu,
        radii=radii,
    )
    row_acc, row_ari, row_nmi, avg_lsm_params = results
    return row_acc, row_ari, row_nmi, avg_lsm_params

def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    try:
        # From here on, stdout and stderr go to a timestamped log file.
        setup_logger("run_gnn")

        # ---- Generator configured from CLI arguments ----
        gen = Generator()
        gen.N_train = args.N_train
        gen.N_test = args.N_test
        gen.N_val = args.N_val
        gen.edge_density = args.edge_density
        gen.p_SBM = args.p_SBM
        gen.q_SBM = args.q_SBM
        gen.random_noise = args.random_noise
        gen.noise = args.noise
        gen.noise_model = args.noise_model
        gen.generative_model = args.generative_model
        gen.n_classes = args.n_classes
        gen.num_examples_train = args.num_examples_train
        gen.num_examples_test = args.num_examples_test
        gen.num_examples_val = args.num_examples_val

        # ---- Checkpoint paths: model_GNN/GNN_model_first_classes{K}/<name> ----
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        filename_first = f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            gen.prepare_data()

            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            train_loader = DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=True,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            val_loader = DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            test_loader = DataLoader(
                test_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )

            # ---- Stage 1: first-period GNN (reuse the saved best weights if present) ----
            print(f"[阶段一] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers}, J={args.J}, num_features={args.num_features}")
            if os.path.exists(path_first):
                print(f"[阶段一] 检测到已有‘最优权重’，直接载入: {path_first}")
                if args.generative_model == 'SBM_multiclass':
                    gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes)
                gnn_first_period = gnn_first_period.to(device)
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)
            else:
                if args.generative_model == 'SBM_multiclass':
                    gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes)
                if torch.cuda.is_available():
                    gnn_first_period = gnn_first_period.to(device)

                loss_list, acc_list = train_first_period_with_early_stopping(
                    gnn_first_period, train_loader, val_loader, args.n_classes, args,
                    epochs=20, save_path=path_first, filename=filename_first
                )
                print(f"[阶段一] 从最优权重回载到内存: {path_first}")
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)

            maybe_freeze(gnn_first_period, freeze=True)

            # ---- Stage 2: second-period GNN trained on top of the frozen stage-1 model ----
            print(f"[阶段二] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers_second}, J={args.J_second}, num_features={args.num_features_second}")
            if os.path.exists(path_second):
                print(f"[阶段二] 检测到已有模型文件，直接载入: {path_second}")
                gnn_second_period = torch.load(path_second, map_location=device, weights_only=False)
                gnn_second_period = gnn_second_period.to(device)
            else:
                if args.generative_model == 'SBM_multiclass':
                    gnn_second_period = GNN_multiclass_second_period(
                        args.num_features_second, args.num_layers_second, args.J_second + 3, n_classes=args.n_classes
                    )
                if torch.cuda.is_available():
                    gnn_second_period = gnn_second_period.to(device)

                loss_list, acc_list = train_second_period_with_early_stopping(
                    gnn_first_period, gnn_second_period, train_loader, val_loader,
                    args.n_classes, args,
                    epochs=20, save_path=path_second, filename=filename_second
                )
                print(f"[阶段二] 保存模型到 {path_second}")

        print("[测试阶段] 开始...")

        def _load_model_for_test(path, build_model):
            """Load the checkpoint at ``path`` for evaluation on ``device``.

            A checkpoint that unpickles to an ``nn.Module`` is used as-is. Any
            other object (e.g. a ``{'model_state': ...}`` dict or a bare
            state_dict, both accepted by ``load_best_model_into``) is loaded into
            a freshly built model instead of failing on ``.to(device)``.
            """
            # NOTE: weights_only=False unpickles arbitrary objects; trusted files only.
            checkpoint = torch.load(path, map_location=device, weights_only=False)
            if isinstance(checkpoint, nn.Module):
                return checkpoint.to(device)
            return load_best_model_into(build_model().to(device), path, device)

        gnn_first_period = _load_model_for_test(
            path_first,
            lambda: GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes),
        )
        gnn_second_period = _load_model_for_test(
            path_second,
            lambda: GNN_multiclass_second_period(
                args.num_features_second, args.num_layers_second, args.J_second + 3, n_classes=args.n_classes
            ),
        )

        # Class-size settings per number of classes, from balanced to increasingly imbalanced.
        class_sizes_dict = {
            2: [
                [500, 500],  # balanced
                [600, 400],  # mildly imbalanced
                [700, 300],  # moderately imbalanced
                [800, 200],
            ],

            4: [
                [250, 250, 250, 250],  # balanced
                [300, 250, 250, 200],  # mildly imbalanced
                [400, 300, 200, 100],  # moderately imbalanced
                [700, 100, 100, 100]  # extremely imbalanced
            ],

            8: [
                [125, 125, 125, 125, 125, 125, 125, 125],  # balanced
                [150, 125, 125, 125, 125, 125, 125, 100],  # mildly imbalanced
                [200, 180, 160, 140, 120, 100, 80, 20],  # moderately imbalanced
                [650, 50, 50, 50, 50, 50, 50, 50]  # extremely imbalanced
            ],
        }

        # Values of C swept for each number of classes.
        total_ab_dict = {
            2: [5, 10, 15],

            4: [15, 20, 25],

            8: [25, 30, 35]
        }

        class_sizes_list = class_sizes_dict[args.n_classes]
        total_ab_list = total_ab_dict[args.n_classes]

        K = args.n_classes
        if K == 2:
            radii_list = [[1.0, 1.2]]
        elif K == 4:
            radii_list = [[0.8, 0.9, 1.0, 1.1]]
        elif K == 8:
            # 0.8 to 1.2 in steps of 0.05.
            # NOTE(review): this yields 9 radii for 8 classes; confirm how gen_one_lsm uses them.
            radii_list = [[round(0.8 + i * 0.05, 2) for i in range(9)]]

        # Graph size assumed by the scaling search; every class_sizes entry above sums to 1000.
        N = 1000
        logN_div_N = np.log(N) / N

        task_args = []

        def precompute_scalings(K, n, C, max_r, snr_list):
            """Map each target SNR to a latent scaling factor, assuming all radii equal ``max_r``.

            A log-spaced grid of 100 scalings over [0.01, 10] is scanned. For each
            target, the grid point whose SNR is closest is returned.
            """
            from scipy.special import expit
            abar = 0.5 * np.log(C * np.log(n) / n)

            def compute_snr(s):
                p = expit(2 * abar + max_r ** 2 * s)
                q = expit(2 * abar - 1 / (K - 1) * max_r ** 2 * s)
                return (p - q) ** 2 * n / (K * (p + (K - 1) * q) * np.log(n))

            # The SNR curve does not depend on the target, so evaluate it once.
            grid = np.logspace(np.log10(0.01), np.log10(10), 100)
            snrs = np.array([compute_snr(s) for s in grid])
            return [grid[np.argmin(np.abs(snrs - target_snr))] for target_snr in snr_list]

        snr_list = [0.25, 0.5, 0.75, 1, 1.5]
        for total_ab in total_ab_list:
            norm_mu_grid = precompute_scalings(K, N, total_ab, 1, snr_list)
            for class_sizes in class_sizes_list:
                for radii in radii_list:
                    for norm_mu in norm_mu_grid:
                        task_args.append((gnn_first_period, gnn_second_period, class_sizes, norm_mu, radii, gen,
                                          logN_div_N, total_ab))

        print("[测试阶段] 并行执行中...")

        # os.cpu_count() may be None or 1; Pool(processes=0) raises ValueError.
        # NOTE(review): the parent may already have initialised CUDA
        # (get_available_device allocates a tensor at import). With a fork-based
        # start method, CUDA use inside the workers can fail; consider a
        # "spawn" context when running on GPU.
        n_processes = max(1, (os.cpu_count() or 2) // 2)
        with Pool(processes=n_processes) as pool:
            results = pool.map(test_first_period_wrapper, task_args)

        # Write summaries to results/lsm.
        lsm_dir = os.path.join("results", "lsm")
        os.makedirs(lsm_dir, exist_ok=True)

        # Pool.map preserves input order, so results line up with task_args.
        for (row_acc, row_ari, row_nmi, avg_lsm_params), task in zip(results, task_args):
            _, _, class_sizes, norm_mu, radii, _, _, C = task
            filename = os.path.join(lsm_dir, f"summary_C{C}.xlsx")
            append_rows_to_excel(
                row_acc, row_ari, row_nmi,
                filename=filename,
                extra_info={
                    "class_sizes": class_sizes,
                    "C": C,
                    "norm_mu": norm_mu,
                    "radii": str(radii),
                    "p_lsm": avg_lsm_params["p_lsm"],
                    "q_lsm": avg_lsm_params["q_lsm"],
                    "SNR": avg_lsm_params["SNR"],
                }
            )

    except Exception:
        import traceback
        traceback.print_exc()
        # Report failure to the shell instead of exiting with status 0.
        sys.exit(1)
