import argparse
import multiprocessing
import os
import random
import sys
import time

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from GNN_local_refinement import run_refinement_chain
from controlsnr import find_a_given_snr
from data_generator_dcsbm import Generator, simple_collate_fn
from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from losses import from_scores_to_labels_multiclass_batch
from models import GNN_multiclass, GNN_multiclass_second_period
from train_first_period import train_first_period_with_early_stopping
from train_second_period import train_second_period_with_early_stopping


def setup_logger(prefix="main_gnn"):
    """Send this process's stdout and stderr to a new line-buffered log file.

    The file is created in the current working directory and named
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``. The handle is not closed here:
    it stays open for as long as ``sys.stdout`` / ``sys.stderr`` refer to it.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S")
    log_filename = "{}_{}_{}.log".format(prefix, stamp, os.getpid())
    # buffering=1 -> line-buffered, so log lines appear as they are printed.
    stream = open(log_filename, "w", buffering=1)
    sys.stdout = sys.stderr = stream
    print(f"[Logger initialized] Logging to: {log_filename}")

def load_best_model_into(model, ckpt_path, device):
    """Load the checkpoint stored at *ckpt_path* and return the model to use.

    Accepted checkpoint layouts:
      * a dict containing ``'model_state'`` -> that state dict is loaded into *model*;
      * an object exposing ``state_dict()`` (e.g. a pickled ``nn.Module``) -> its
        state dict is loaded into *model*; if that raises, the loaded object itself
        is moved to *device* and returned instead of *model*;
      * anything else -> treated as a raw state dict for *model*.

    All loads use ``strict=False``, so missing or unexpected keys do not raise.
    """
    ckpt = torch.load(ckpt_path, map_location=device)

    if isinstance(ckpt, dict) and 'model_state' in ckpt:
        state = ckpt['model_state']
    elif not hasattr(ckpt, 'state_dict'):
        state = ckpt
    else:
        # Whole-object checkpoint: prefer copying its weights into *model*,
        # fall back to using the checkpoint object directly.
        try:
            model.load_state_dict(ckpt.state_dict(), strict=False)
        except Exception:
            return ckpt.to(device)
        return model

    model.load_state_dict(state, strict=False)
    return model

def maybe_freeze(model, freeze=True):
    """Put *model* in eval mode and, when *freeze* is true, stop gradient tracking.

    Returns None; *model* is modified in place.
    """
    params_to_freeze = model.parameters() if freeze else ()
    for param in params_to_freeze:
        param.requires_grad_(False)
    # eval() is applied regardless of *freeze*.
    model.eval()

# Command-line options. They are parsed once, at import time, into the
# module-level `args` (end of this section). Most options use
# `nargs='?', const=1`: per argparse semantics, passing the flag with no value
# stores the `const` (1) rather than the default.
parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#              Parameters configured up front for use later in the file        #
###############################################################################

# Number of examples per split; copied onto the Generator in __main__.
# num_examples_test is also the default iteration count of test_first_period_split.
parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6000))
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(30))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1000))
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
# Within-/between-community edge probabilities. Copied onto the Generator in
# __main__; test_first_period_wrapper overrides them per (SNR, total_ab) task.
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
# NOTE(review): args.class_sizes is not read in the visible code; the test
# phase iterates class_sizes_dict in __main__ instead — confirm it is still used.
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
# Models are only built for 'SBM_multiclass' in __main__.
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 1)
# 'train' trains (or reloads) both stages before testing; any other value
# only runs the test phase.
parser.add_argument('--mode', nargs='?', const=1, type=str, default='test')
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): set_defaults below already makes eval_vs_train True, so passing
# --freeze_bn has no effect — confirm whether the default should be False.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

###############################################################################
#                                 GNN first period                            #
###############################################################################
# Used in __main__ as GNN_multiclass(num_features, num_layers, J + 3, ...).
parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
# Used in __main__ as GNN_multiclass_second_period(num_features_second,
# num_layers_second, J_second + 3, ...).
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default= 8)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default= 3)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

# The test phase only defines class_sizes_dict / total_ab_dict entries for
# n_classes in {2, 4, 8}.
parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=2)
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

# Not read in the visible code; presumably consumed by the training helpers,
# which receive `args` — TODO confirm.
parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

args = parser.parse_args()

# Legacy tensor-type aliases used with `.type(...)`: the CUDA variants when a
# GPU is visible, the CPU variants otherwise.
if not torch.cuda.is_available():
    dtype, dtype_l = torch.FloatTensor, torch.LongTensor
else:
    dtype, dtype_l = torch.cuda.FloatTensor, torch.cuda.LongTensor

batch_size = args.batch_size
# NOTE(review): `criterion` is not referenced in the visible part of this file.
criterion = nn.CrossEntropyLoss()

# Column layouts for the per-sample log lines printed during testing.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'

# Wider layouts; not referenced in the visible part of this file.
template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row    = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'

class SBMDataset(torch.utils.data.Dataset):
    """Map-style dataset with one graph per ``.npz`` file.

    Each file must contain the CSR components of the adjacency matrix
    (``adj_data``, ``adj_indices``, ``adj_indptr``, ``adj_shape``) and a
    ``labels`` array.
    """

    def __init__(self, npz_file_list):
        # Paths only; files are opened lazily in __getitem__.
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load graph *idx* as ``{'adj': csr_matrix, 'labels': ndarray, 'num_nodes': int}``."""
        # np.load on an .npz returns an NpzFile that keeps the underlying file
        # open; the context manager closes it once the arrays have been read.
        # Without it, DataLoader workers leak a descriptor per sample.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}

# Sparse helpers
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors, local_refinement_by_neighbors_multi

import pandas as pd
import numpy as np

def get_available_device():
    """Return the first CUDA device that can allocate a tensor, else the CPU.

    Each candidate index is made current with ``torch.cuda.set_device`` and
    probed with a one-element allocation; a ``RuntimeError`` skips it. The
    returned CUDA device is therefore also left as the current device.
    """
    def _usable(idx):
        try:
            torch.cuda.set_device(idx)
            torch.zeros(1).cuda()
        except RuntimeError:
            return False
        return True

    # Lazily scan the indices and stop at the first usable one.
    chosen = next((i for i in range(torch.cuda.device_count()) if _usable(i)), None)
    if chosen is None:
        return torch.device("cpu")
    return torch.device(f"cuda:{chosen}")


device = get_available_device()

def test_both_models_first_period(
    gnn_first_period, gnn_second_period, gen, n_classes, args, iter,
    class_sizes=None, print_each=True
):
    """Evaluate the two-stage pipeline on one SBM sample and one DCBM sample.

    For each graph family this:
      1. samples a graph via ``gen.imbalanced_sample_otf_single`` (SBM) or
         ``gen.imbalanced_dcsbm_sample_otf_single`` (DCBM);
      2. runs ``gnn_first_period`` to obtain initial labels;
      3. runs ``run_refinement_chain`` with the second-period GNN (10 iterations);
      4. runs one step of ``local_refinement_by_neighbors_multi`` from the
         first-period labels.

    Args:
        gnn_first_period, gnn_second_period: the two models. Both are switched
            to ``train()`` mode here.
        gen: sample generator exposing the two sampling methods above.
        n_classes: number of communities.
        args: parsed CLI arguments. ``J``, ``edge_density`` and ``noise`` are
            read here; ``args`` is also forwarded to ``run_refinement_chain``.
        iter: iteration index, used only in the printed log line.
        class_sizes: community sizes. For each sample the order is reversed
            with probability 0.5.
        print_each: print a loss/accuracy line for every sample.

    Returns:
        ``(avg_loss, avg_metrics, details)``:
          * ``avg_loss`` is the mean of the two first-period losses;
          * ``avg_metrics`` is the element-wise mean of the two metric dicts, with
            keys ``gnn`` / ``gnn_second`` / ``gnn_refined`` / ``gnn_final``, each
            holding ``acc`` / ``ari`` / ``nmi``;
          * ``details`` is ``{'sbm': {'loss', 'metrics'}, 'dcbm': {...}}`` with
            the un-averaged results of each run.
    """

    def _sync():
        # Wait for queued CUDA work so the wall-clock timing below is meaningful.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): evaluation runs in train() mode. If the models contain
    # BatchNorm/Dropout this changes the outputs — confirm this is intended.
    gnn_first_period.train()
    gnn_second_period.train()

    def _run_one(sample_method_name):
        """Run the full pipeline on one graph from *sample_method_name*; return (loss, metrics)."""
        t_all0 = time.perf_counter()

        # Reverse the class-size order half of the time. random.random() is
        # always drawn first so RNG consumption does not depend on class_sizes.
        if random.random() < 0.5 and class_sizes is not None:
            random_class_sizes = class_sizes[::-1]
        else:
            random_class_sizes = class_sizes

        # Both sampling methods share the (W, true_labels) return interface.
        sample_fn = getattr(gen, sample_method_name)
        W, true_labels = sample_fn(random_class_sizes, is_training=True, cuda=True)

        # Cast labels with the module-level long tensor type (CUDA when available).
        true_labels = true_labels.type(dtype_l)

        # Downstream helpers take the adjacency as a NumPy array.
        W_np = W.detach().cpu().numpy() if torch.is_tensor(W) else np.asarray(W)

        # GNN inputs; per the original author's note, WW is (B, N, N, J+3) and
        # x is (B, N, d).
        WW, x = get_gnn_inputs(W_np, args.J)
        WW = WW.clone().detach().to(torch.float32).to(device)
        x = x.clone().detach().to(torch.float32).to(device)

        # ===== First-period forward pass =====
        with torch.no_grad():
            _sync()
            pred_single_first = gnn_first_period(WW, x)
            _sync()
            start_x_label = from_scores_to_labels_multiclass_batch(
                pred_single_first.detach().cpu().numpy()
            )

        # ===== Refinement chain (second-period GNN) =====
        with torch.no_grad():
            ref = run_refinement_chain(
                gnn_second_period=gnn_second_period,
                W_np=W_np,
                init_labels=start_x_label,
                true_labels=true_labels,
                args=args,
                device=device,
                total_iters=10,
                verbose=False
            )
            _sync()

        # Second-period metrics come from the chain's first iteration (0.0 if
        # absent); final metrics come from its last.
        if ref.get("first_iter") is not None:
            acc_gnn_second = ref["first_iter"]["acc"]
            ari_gnn_second = ref["first_iter"]["ari"]
            nmi_gnn_second = ref["first_iter"]["nmi"]
        else:
            acc_gnn_second = ari_gnn_second = nmi_gnn_second = 0.0

        acc_gnn_final = ref["final"]["acc"]
        ari_gnn_final = ref["final"]["ari"]
        nmi_gnn_final = ref["final"]["nmi"]

        # ===== Loss and first-period metrics =====
        loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)

        acc_gnn_first, _, ari_gnn_first, nmi_gnn_first = gnn_compute_acc_ari_nmi_multiclass(
            pred_single_first, true_labels, n_classes
        )

        # One step of neighbour-vote refinement from the first-period labels.
        res_multi = local_refinement_by_neighbors_multi(
            A=W_np,
            init_labels=start_x_label,
            num_classes=n_classes,
            true_labels=true_labels,
            num_iters=1,
            alpha=1e-6,
            random_state=0,
            tol=0,
            verbose=False,
            return_history=True
        )

        # Metrics after the first refinement step. An empty history falls back
        # to 0.0, like the missing-key defaults.
        history = res_multi["history"]
        first_iter = history[0] if len(history) > 0 else {}
        acc_gnn_refined = first_iter.get("acc", 0.0)
        ari_gnn_refined = first_iter.get("ari", 0.0)
        nmi_gnn_refined = first_iter.get("nmi", 0.0)

        total_elapsed = time.perf_counter() - t_all0
        # .item() copies a one-element tensor to a Python float on any device.
        loss_value = float(loss_test_first.item())

        if print_each:
            info = ['iter', 'avg loss', 'avg acc', 'edge_density', 'noise', 'model', 'elapsed']
            out  = [iter, loss_value, acc_gnn_first, args.edge_density, args.noise, sample_method_name, total_elapsed]
            print(template1.format(*info))
            print(template2.format(*out))

        # Drop references to the (potentially large) GNN inputs.
        del WW, x

        metrics = {
            "gnn":         {"acc": float(acc_gnn_first),   "ari": float(ari_gnn_first),   "nmi": float(nmi_gnn_first)},
            "gnn_second":  {"acc": float(acc_gnn_second),  "ari": float(ari_gnn_second),  "nmi": float(nmi_gnn_second)},
            "gnn_refined": {"acc": float(acc_gnn_refined), "ari": float(ari_gnn_refined), "nmi": float(nmi_gnn_refined)},
            "gnn_final": {"acc": float(acc_gnn_final), "ari": float(ari_gnn_final), "nmi": float(nmi_gnn_final)}
        }
        return loss_value, metrics

    # ------- Run the SBM and DCBM samples -------
    loss_sbm,  metrics_sbm  = _run_one("imbalanced_sample_otf_single")
    loss_dcbm, metrics_dcbm = _run_one("imbalanced_dcsbm_sample_otf_single")

    # ------- Element-wise averages -------
    def _avg_two(a, b):
        return 0.5 * (float(a) + float(b))

    def _avg_metrics_dict(m1, m2):
        out = {}
        for key in m1.keys():  # gnn / gnn_second / gnn_refined / gnn_final
            out[key] = {}
            for k2 in m1[key].keys():  # acc / ari / nmi
                out[key][k2] = _avg_two(m1[key][k2], m2[key][k2])
        return out

    avg_loss = _avg_two(loss_sbm, loss_dcbm)
    avg_metrics = _avg_metrics_dict(metrics_sbm, metrics_dcbm)

    details = {
        "sbm":  {"loss": loss_sbm,  "metrics": metrics_sbm},
        "dcbm": {"loss": loss_dcbm, "metrics": metrics_dcbm},
    }
    return avg_loss, avg_metrics, details


# Per-method metric keys collected for each test run; also the column order of
# the method columns in the summary rows built by test_first_period_split.
METHODS = [
    "gnn",
    "gnn_second",
    "gnn_refined",
    "gnn_final",
]

def test_first_period_split(
    gnn_first_period,
    gnn_second_period,
    n_classes,
    gen,
    args,
    iters=None,
    mode='balanced',
    class_sizes=None,
):
    """Repeat the joint SBM/DCBM test and aggregate each graph family separately.

    Every iteration calls ``test_both_models_first_period`` once. That call
    samples one SBM and one DCBM graph. The per-method metrics are filed into a
    separate bucket per family and averaged with ``np.nanmean``; non-finite
    values count as NaN.

    Args:
        iters: number of iterations; defaults to ``args.num_examples_test``.
        mode: accepted for interface compatibility; unused here.
        class_sizes: forwarded to ``test_both_models_first_period``.

    Returns:
        ``(rows_sbm, rows_dcbm)``. Each is a tuple ``(row_acc, row_ari, row_nmi)``
        of dicts holding shared metadata (n_classes, class_sizes, p/q, J, N_train,
        N_test, SNR) followed by one mean value per entry of ``METHODS``.
    """
    num_iters = args.num_examples_test if iters is None else iters
    metric_names = ("acc", "ari", "nmi")

    def _empty_bucket():
        return {meth: {name: [] for name in metric_names} for meth in METHODS}

    buckets = {"sbm": _empty_bucket(), "dcbm": _empty_bucket()}

    gnn_first_period.train()
    gnn_second_period.train()

    for it in range(num_iters):
        # One joint run per iteration; only `details` is used, to keep the
        # two families apart.
        _, _, details = test_both_models_first_period(
            gnn_first_period=gnn_first_period,
            gnn_second_period=gnn_second_period,
            gen=gen,
            n_classes=n_classes,
            args=args,
            iter=it,
            class_sizes=class_sizes,
            print_each=True,
        )

        for family in ("sbm", "dcbm"):
            per_method = details[family]["metrics"]
            family_bucket = buckets[family]
            for meth in METHODS:
                scores = per_method.get(meth, {})
                for name in metric_names:
                    value = scores.get(name, np.nan)
                    family_bucket[meth][name].append(float(value) if np.isfinite(value) else np.nan)

        torch.cuda.empty_cache()

    def _nan_mean(values):
        return float(np.nanmean(np.asarray(values, dtype=float)))

    # ------- Metadata shared by every row -------
    n = getattr(args, "N_train", getattr(gen, "N_train", None)) or getattr(args, "N_test", getattr(gen, "N_test", -1))
    if n and n > 0:
        logn_div_n = np.log(n) / n
    else:
        logn_div_n = np.nan

    # Convert p, q back to the a, b scale (p = a * log(n) / n).
    scale_is_finite = np.isfinite(logn_div_n)
    a = gen.p_SBM / logn_div_n if scale_is_finite else np.nan
    b = gen.q_SBM / logn_div_n if scale_is_finite else np.nan
    k = n_classes
    if np.all(np.isfinite([a, b])):
        snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
    else:
        snr = np.nan

    meta = {
        "n_classes": int(n_classes),
        "class_sizes": str(class_sizes) if class_sizes is not None else "",
        "p_SBM": float(gen.p_SBM),
        "q_SBM": float(gen.q_SBM),
        "J": int(args.J),
        "N_train": int(getattr(args, "N_train", getattr(gen, "N_train", -1))),
        "N_test": int(getattr(args, "N_test", getattr(gen, "N_test", -1))),
        "SNR": float(snr) if np.isfinite(snr) else np.nan,
    }

    def _rows_for(bucket):
        # One row per metric (acc, ari, nmi): metadata first, then each method's mean.
        rows = []
        for name in metric_names:
            row = dict(meta)
            for meth in METHODS:
                row[meth] = _nan_mean(bucket[meth][name])
            rows.append(row)
        return tuple(rows)

    return _rows_for(buckets["sbm"]), _rows_for(buckets["dcbm"])


from pathlib import Path

def append_rows_to_excel(row_acc: dict, row_ari: dict, row_nmi: dict, filename="summary.xlsx", extra_info: dict=None):
    """Append one row each to the "ACC", "ARI" and "NMI" sheets of *filename*.

    Each row is merged with *extra_info*, whose values win on key clashes. An
    ``extra_info["class_sizes"]`` that is a list/tuple/ndarray is flattened to a
    dash-joined string (e.g. ``"500-500"``); any other value is ``str()``-ed.

    If the workbook does not exist it is created. Otherwise the sheet's
    previous contents are read back, aligned on the union of columns, and the
    sheet is rewritten with the new row appended. If that read/merge raises
    ``ValueError`` (e.g. the sheet does not exist yet), the sheet is written
    with just the new row. Uses the ``openpyxl`` engine.
    """
    def _flatten_extra(info):
        if not info:
            return {}
        flat = dict(info)
        if "class_sizes" not in flat:
            return flat
        sizes = flat["class_sizes"]
        if isinstance(sizes, (list, tuple, np.ndarray)):
            flat["class_sizes"] = "-".join(map(str, sizes))
        else:
            flat["class_sizes"] = str(sizes)
        return flat

    def _append_to_sheet(sheet_name, row):
        record = dict(row)
        record.update(_flatten_extra(extra_info))
        fresh = pd.DataFrame([record])
        workbook = Path(filename)

        # Existence is re-checked per sheet: the first write creates the file.
        if not workbook.exists():
            with pd.ExcelWriter(filename, engine="openpyxl", mode="w") as writer:
                fresh.to_excel(writer, sheet_name=sheet_name, index=False)
            return

        try:
            previous = pd.read_excel(workbook, sheet_name=sheet_name)
            all_cols = list(dict.fromkeys([*previous.columns, *fresh.columns]))
            previous = previous.reindex(columns=all_cols)
            fresh = fresh.reindex(columns=all_cols)
            combined = pd.concat([previous, fresh], ignore_index=True)
        except ValueError:
            combined = fresh
        with pd.ExcelWriter(filename, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
            combined.to_excel(writer, sheet_name=sheet_name, index=False)

    for sheet_name, row in (("ACC", row_acc), ("ARI", row_ari), ("NMI", row_nmi)):
        _append_to_sheet(sheet_name, row)

def test_first_period_wrapper(args_tuple):
    """Pool worker: evaluate one (class_sizes, SNR, total_ab) configuration.

    *args_tuple* is ``(gnn_first_period, gnn_second_period, class_sizes, snr,
    gen, logN_div_N, total_ab)``. ``find_a_given_snr`` turns the SNR into
    ``(a, b)``. These are scaled by ``logN_div_N`` and rounded to 4 decimals to
    obtain ``p_SBM`` / ``q_SBM``, which are set on ``gen.copy()``.
    ``test_first_period_split`` then runs ``args.num_examples_test`` iterations.
    The module-level ``args`` supplies n_classes, num_examples_test and
    mode_isbalanced.

    Returns:
        ``((row_acc_sbm, row_ari_sbm, row_nmi_sbm),
           (row_acc_dcbm, row_ari_dcbm, row_nmi_dcbm))``.
    """
    (gnn_first_period, gnn_second_period, class_sizes,
     snr, gen, logN_div_N, total_ab) = args_tuple

    a, b = find_a_given_snr(snr, args.n_classes, total_ab)
    p_SBM, q_SBM = (round(coef * logN_div_N, 4) for coef in (a, b))

    # Presumably a copy so the caller's generator is not mutated — relies on
    # Generator.copy() returning an independent object.
    gen_local = gen.copy()
    gen_local.p_SBM, gen_local.q_SBM = p_SBM, q_SBM

    print(f"\n[测试阶段] class_sizes: {class_sizes}, SNR: {snr:.2f}")
    print(f"使用的 SBM 参数: p={p_SBM}, q={q_SBM}")

    # SBM and DCBM results are aggregated separately.
    (acc_s, ari_s, nmi_s), (acc_d, ari_d, nmi_d) = test_first_period_split(
        gnn_first_period=gnn_first_period,
        gnn_second_period=gnn_second_period,
        n_classes=args.n_classes,
        gen=gen_local,
        args=args,
        iters=args.num_examples_test,
        mode=args.mode_isbalanced,
        class_sizes=class_sizes,
    )
    return (acc_s, ari_s, nmi_s), (acc_d, ari_d, nmi_d)

def count_parameters(model):
    """Return the total number of scalar entries in *model*'s trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

if __name__ == '__main__':
    try:
        # Redirect stdout/stderr to a per-run log file before anything is printed.
        setup_logger("run_gnn")

        # Configure the graph generator from the CLI arguments.
        gen = Generator()
        gen.N_train = args.N_train
        gen.N_test = args.N_test
        gen.N_val = args.N_val
        gen.edge_density = args.edge_density
        gen.p_SBM = args.p_SBM
        gen.q_SBM = args.q_SBM
        gen.random_noise = args.random_noise
        gen.noise = args.noise
        gen.noise_model = args.noise_model
        gen.generative_model = args.generative_model
        gen.n_classes = args.n_classes
        gen.num_examples_train = args.num_examples_train
        gen.num_examples_test = args.num_examples_test
        gen.num_examples_val = args.num_examples_val

        # Checkpoints live in model_GNN/GNN_model_first_classes{n_classes}/.
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        filename_first = f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            # Models are only constructed for 'SBM_multiclass' below. Fail fast
            # with a clear message instead of a later NameError on the undefined
            # model (and before the expensive prepare_data call).
            if args.generative_model != 'SBM_multiclass':
                raise ValueError(
                    f"Unsupported generative_model for training: {args.generative_model!r}; "
                    "only 'SBM_multiclass' is supported"
                )

            gen.prepare_data()

            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            train_loader = DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=True,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            val_loader = DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            test_loader = DataLoader(
                test_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )

            # ----- Phase 1: first-period GNN (reuse checkpoint if present) -----
            print(f"[阶段一] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers}, J={args.J}, num_features={args.num_features}")
            if os.path.exists(path_first):
                print(f"[阶段一] 检测到已有‘最优权重’，直接载入: {path_first}")
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes)
                gnn_first_period = gnn_first_period.to(device)
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)
            else:
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes)
                if torch.cuda.is_available():
                    gnn_first_period = gnn_first_period.to(device)

                loss_list, acc_list = train_first_period_with_early_stopping(
                    gnn_first_period, train_loader, val_loader, args.n_classes, args,
                    epochs=100, patience=5, save_path=path_first, filename=filename_first
                )
                print(f"[Phase 1] Backload from optimal weight to memory: {path_first}")
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)

            # Phase 2 trains on top of a frozen first-period model.
            maybe_freeze(gnn_first_period, freeze=True)

            # ----- Phase 2: second-period GNN (reuse checkpoint if present) -----
            print(f"[Phase 2] Training GNN:n_classes={args.n_classes}, layers={args.num_layers_second}, J={args.J_second}, num_features={args.num_features_second}")
            if os.path.exists(path_second):
                print(f"[Stage 2] Detect the existing model file and load it directly: {path_second}")
                gnn_second_period = torch.load(path_second, map_location=device)
                gnn_second_period = gnn_second_period.to(device)
            else:
                gnn_second_period = GNN_multiclass_second_period(
                    args.num_features_second, args.num_layers_second, args.J_second + 3, n_classes=args.n_classes
                )
                if torch.cuda.is_available():
                    gnn_second_period = gnn_second_period.to(device)

                loss_list, acc_list = train_second_period_with_early_stopping(
                    gnn_first_period, gnn_second_period, train_loader, val_loader,
                    args.n_classes, args,
                    epochs=100, patience=5, save_path=path_second, filename=filename_second
                )
                print(f"[Phase 2] Save the model to {path_second}")

        print("[Test phase] starts...")

        # Load both models onto the selected device.
        # NOTE(review): this expects whole pickled nn.Module objects at both paths
        # (`.to` is called on the result). The training branch instead uses
        # load_best_model_into, which also accepts {'model_state': ...} dicts —
        # confirm the training helpers save whole modules.
        gnn_first_period = torch.load(path_first, map_location=device)
        gnn_first_period = gnn_first_period.to(device)
        gnn_second_period = torch.load(path_second, map_location=device)
        gnn_second_period = gnn_second_period.to(device)

        # Community-size configurations per number of classes.
        class_sizes_dict = {
            2: [
                [500, 500],  # balanced
                [600, 400],  # mildly imbalanced
                [700, 300],  # moderately imbalanced
                [800, 200],

            ],

            4: [
                [250, 250, 250, 250],  # balanced
                [300, 250, 250, 200],  # mildly imbalanced
                [400, 300, 200, 100],  # moderately imbalanced
                [700, 100, 100, 100]  # extremely imbalanced
            ],

            8: [
                [125, 125, 125, 125, 125, 125, 125, 125],  # balanced
                [150, 125, 125, 125, 125, 125, 125, 100],  # mildly imbalanced
                [200, 180, 160, 140, 120, 100, 80, 20],  # moderately imbalanced
                [650, 50, 50, 50, 50, 50, 50, 50]  # extremely imbalanced
            ],
        }

        # total_ab values passed to find_a_given_snr, per number of classes.
        total_ab_dict = {
            2: [5, 10, 15],

            4: [15, 20, 25],

            8: [25, 30, 35]
        }

        class_sizes_list = class_sizes_dict[args.n_classes]
        total_ab_list = total_ab_dict[args.n_classes]

        snr_list = [0.25, 0.5, 0.75, 1, 1.5]

        N = 1000
        logN_div_N = np.log(N) / N

        # One task per (total_ab, class_sizes, snr) combination.
        task_args = []
        for total_ab in total_ab_list:
            for class_sizes in class_sizes_list:
                for snr in snr_list:
                    task_args.append((gnn_first_period, gnn_second_period, class_sizes, snr, gen, logN_div_N, total_ab))

        print("[Test Phase] Parallel execution...")

        # Use half the CPUs but at least one worker. os.cpu_count() may be 1 or
        # None, and Pool rejects processes=0.
        # NOTE(review): each task tuple carries the models. If they live on a
        # CUDA device, sending them to forked workers can fail with "Cannot
        # re-initialize CUDA in forked subprocess" — consider a 'spawn' context.
        n_processes = max(1, (os.cpu_count() or 2) // 2)
        with multiprocessing.Pool(processes=n_processes) as pool:
            results = pool.map(test_first_period_wrapper, task_args)

        # Results root directory.
        BASE_DIR = "results"
        sbm_dir = os.path.join(BASE_DIR, "sbm")
        dcbm_dir = os.path.join(BASE_DIR, "dcbm")
        os.makedirs(sbm_dir, exist_ok=True)
        os.makedirs(dcbm_dir, exist_ok=True)

        # Write each task's rows: one workbook per total_ab, separately under sbm/ and dcbm/.
        for ((rows_sbm, rows_dcbm), task) in zip(results, task_args):
            (row_acc_sbm, row_ari_sbm, row_nmi_sbm) = rows_sbm
            (row_acc_dcbm, row_ari_dcbm, row_nmi_dcbm) = rows_dcbm
            (gnn_fp, gnn_fp_second, class_sizes, snr, gen, logN_div_N, total_ab) = task

            sbm_file = os.path.join(sbm_dir, f"summary_total{total_ab}.xlsx")
            dcbm_file = os.path.join(dcbm_dir, f"summary_total{total_ab}.xlsx")

            # Context columns appended to every row.
            extra = {"class_sizes": class_sizes, "snr": snr, "total_ab": total_ab}

            append_rows_to_excel(row_acc_sbm, row_ari_sbm, row_nmi_sbm, filename=sbm_file, extra_info=extra)
            append_rows_to_excel(row_acc_dcbm, row_ari_dcbm, row_nmi_dcbm, filename=dcbm_file, extra_info=extra)

    except Exception:
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells/schedulers see the failure instead of status 0.
        sys.exit(1)
