import argparse
import multiprocessing
import os
import sys
import time

import pandas as pd
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from scipy.io import loadmat
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader

from data_generator_dcsbm import Generator, simple_collate_fn
from load import get_gnn_inputs
from load_local_refinement import get_gnn_inputs_local_refinement
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from losses import from_scores_to_labels_multiclass_batch
from models import GNN_multiclass
from spectral_clustering import spectral_clustering_adj, local_refinement_by_neighbors_multi
from train_first_period import train_first_period_with_early_stopping


def setup_logger(prefix="main_gnn"):
    """Redirect this process's stdout and stderr to a freshly created log file.

    The file is created in the current working directory as
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``. It is line-buffered, so each
    printed line reaches disk promptly. It is UTF-8 encoded, so the non-ASCII
    (Chinese) messages printed elsewhere in this script can be written
    regardless of the platform's default encoding.

    :param prefix: file-name prefix for the log file.
    :return: the open log file object. It is deliberately left open for the
        life of the process; close it only after restoring
        ``sys.stdout`` / ``sys.stderr``.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    log_filename = f"{prefix}_{timestamp}_{os.getpid()}.log"
    # Explicit encoding: the platform default (e.g. cp1252) cannot encode the
    # Chinese text this script prints and would raise UnicodeEncodeError.
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")  # line-buffered
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")
    return logfile


parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#          Declared up front so later code can read them from `args`          #
###############################################################################

# Most options use nargs='?' with const=1: per argparse semantics, passing the
# flag with no value stores 1; otherwise the typed value (or the default) is used.
parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6000))
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(100))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1000))
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 16)
# 'train' trains the first-period GNN before the test phase; any other value
# runs the test phase only (see the __main__ block).
parser.add_argument('--mode', nargs='?', const=1, type=str, default='test')
# Directory containing this script (os.path.join with one argument returns it unchanged).
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): '--freeze_bn' stores True into eval_vs_train, and the default
# below is also True, so the flag currently has no effect. Confirm whether
# action='store_false' (or a default of False) was intended.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default=8)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default=3)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

# Copied onto the Generator as gen.N_train / gen.N_test / gen.N_val in __main__.
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

args = parser.parse_args()

# Tensor type constructors for the active backend: CUDA types when a GPU is
# visible, CPU types otherwise. Only the selected branch is evaluated.
dtype, dtype_l = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor)
    if torch.cuda.is_available()
    else (torch.FloatTensor, torch.LongTensor)
)

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Format strings for column-aligned progress tables: each pair is a plain
# header layout followed by a row layout with numeric precision.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'
template3 = '{:<10} {:<10} {:<10} '
template4 = '{:<10} {:<10.5f} {:<10.5f} \n'

template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row    = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'


class SBMDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-generated graphs stored as ``.npz`` files.

    Each file must contain a CSR adjacency split into ``adj_data``,
    ``adj_indices``, ``adj_indptr`` and ``adj_shape``, plus a ``labels`` array.
    """

    def __init__(self, npz_file_list):
        # Paths are stored as given; files are opened lazily in __getitem__.
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load graph ``idx`` and return ``{'adj', 'labels', 'num_nodes'}``."""
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # The context manager closes it deterministically; otherwise each item
        # leaks a descriptor until garbage collection, which adds up in
        # long-running DataLoader workers. Members are read fully into memory,
        # so the arrays stay valid after the file is closed.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}


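# Additional imports. SBMDataset above also uses np and csr_matrix; that works
# because those names are only looked up when __getitem__ runs, after these lines execute.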
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors
import numpy as np
from scipy.sparse import issparse
from scipy.sparse.linalg import eigsh as sparse_eigsh
from numpy.linalg import eigh as dense_eigh



def extract_spectral_features(
    adj,
    k: int,
    *,
    row_norm: bool = True,
    use_abs: bool = True,   # True: rank by |λ|; False: rank by largest algebraic eigenvalue
):
    """
    Eigen-decompose the adjacency matrix A and return its top-k eigenvectors as node features.

    Parameters
    ----
    adj : (N, N) sparse or dense adjacency matrix (expected to be real symmetric)
    k   : number of eigenvectors to keep; must be >= 1
    row_norm : L2-normalise each row (default True; usually more stable)
    use_abs  : rank eigenvalues by |λ| (default True). If False, keep the
               largest algebraic (most positive) eigenvalues instead.

    Returns
    ----
    U : np.ndarray of shape (N, min(k, N)). Columns are ordered by the
        selection criterion, descending.

    Raises
    ----
    ValueError : if k < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    if issparse(adj):
        adj = adj.astype(np.float64)  # ARPACK rejects integer matrices
        n = adj.shape[0]
        if k <= n - 2:
            # Sparse path: ARPACK computes only the k requested eigenpairs.
            # "LM" = largest magnitude, "LA" = largest algebraic value.
            w, V = sparse_eigsh(adj, k=k, which="LM" if use_abs else "LA")
        else:
            # ARPACK cannot return (almost) the whole spectrum. Such problems
            # are tiny, so a dense decomposition is valid and cheap, and it
            # returns k columns instead of silently truncating to N-2.
            w, V = dense_eigh(adj.toarray())
    else:
        w, V = dense_eigh(np.asarray(adj, dtype=np.float64))

    # One selection rule for both paths: sort by the criterion, descending,
    # and keep the first k (fewer if the matrix has fewer eigenpairs).
    score = np.abs(w) if use_abs else w
    top = np.argsort(score)[::-1][:k]
    U = V[:, top]  # fancy indexing returns a copy, so in-place ops below are safe

    # Align column signs (eigenvectors are only defined up to ±1) so repeated
    # runs do not flip whole columns.
    col_mean = U.mean(axis=0, keepdims=True)
    U *= np.where(col_mean >= 0, 1.0, -1.0)

    # Optional row normalisation: scale each node's vector to unit length,
    # removing magnitude differences between nodes.
    if row_norm:
        U = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)

    return U


import joblib

# Default the OpenMP/MKL thread counts to 1. setdefault leaves any value
# already present in the environment untouched.
# NOTE(review): numpy and torch are imported earlier in this file, so these
# variables may be set too late to change this process's own thread pools.
# They presumably mainly affect child processes started afterwards (e.g. the
# joblib workers in train_logistic_regression). Confirm, and move this above
# the imports if the parent process should be limited too.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

def _load_and_feat(path, k):
    """Load one ``.npz`` graph and return ``(spectral_features, labels)``.

    :param path: ``.npz`` file with CSR parts ``adj_data``/``adj_indices``/
        ``adj_indptr``/``adj_shape`` and a ``labels`` array.
    :param k: number of eigenvectors passed to ``extract_spectral_features``.
    :return: ``(features, labels)``, where labels are flattened to a 1-D int array.
    """
    # Close the NpzFile handle deterministically; this runs inside joblib
    # worker processes over many files. Members are read into memory, so the
    # arrays remain valid after the with-block.
    with np.load(path) as data:
        adj = csr_matrix((data['adj_data'], data['adj_indices'], data['adj_indptr']),
                         shape=tuple(data['adj_shape']))
        labels = np.asarray(data['labels']).ravel().astype(int)
    feats = extract_spectral_features(adj, k)  # sparse input -> ARPACK (eigsh) path
    return feats, labels

def train_logistic_regression(gen, k, save_path="./models/node_info_added_lr_model.pkl",
                              n_jobs=max(1, (os.cpu_count() or 2) // 2)):
    """Fit (or load a cached) logistic-regression node classifier on spectral features.

    If ``save_path`` exists, the model stored there is loaded and returned
    without retraining. Otherwise:
    1. ``k`` spectral features per node are extracted in parallel from every
       file in ``gen.data_train``.
    2. One model is fit on all nodes pooled together.
    3. The model is saved to ``save_path``.

    :param gen: object exposing ``data_train``, a list of ``.npz`` graph paths.
    :param k: number of eigenvector features per node.
    :param save_path: cache file for the fitted model.
    :param n_jobs: joblib worker count. Defaults to half the CPUs, but at
        least 1 (``os.cpu_count()`` may return None, and joblib rejects 0).
    :return: the fitted (or loaded) ``LogisticRegression``.
    :raises ValueError: if ``gen.data_train`` is empty and no cached model exists.
    """
    # If the model already exists, load it directly.
    if os.path.exists(save_path):
        # SECURITY: joblib.load unpickles arbitrary objects; only point
        # save_path at files you trust.
        print(f"检测到已存在的模型文件 {save_path}，直接加载。")
        return joblib.load(save_path)

    train_files = list(gen.data_train)
    if not train_files:
        raise ValueError("gen.data_train is empty; nothing to train the logistic regression on")

    # Extract features in parallel, one graph per task.
    results = Parallel(n_jobs=n_jobs, prefer="processes")(
        delayed(_load_and_feat)(fp, k) for fp in train_files
    )
    all_features, all_labels = zip(*results)

    X_train = np.vstack(all_features).astype(np.float32, copy=False)
    y_train = np.concatenate(all_labels)

    # NOTE(review): recent scikit-learn releases deprecate `multi_class`;
    # check the installed version before removing it, since dropping it can
    # change the binary-class fit on older versions.
    lr_model = LogisticRegression(
        multi_class='multinomial',
        solver='saga',  # scales to large (and sparse) datasets
        max_iter=1000,
        random_state=42
    )

    lr_model.fit(X_train, y_train)

    # Save the model. A bare filename has no directory part, and
    # os.makedirs('') would raise FileNotFoundError.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    joblib.dump(lr_model, save_path)
    print(f"模型已保存到 {save_path}")

    return lr_model


def evaluate_using_logitic(W, true_labels, lr_model, k):
    """
    Score a fitted logistic-regression node classifier on a single graph.

    :param W: adjacency of the graph under test. If its first axis has length 1,
        that axis is squeezed off first.
    :param true_labels: ground-truth node labels.
    :param lr_model: fitted classifier exposing ``predict`` on spectral features.
    :param k: number of communities; also the number of spectral features extracted.
    :return: ``(acc, ari, nmi)`` of the raw predictions, followed by
        ``(acc, ari, nmi)`` after neighbour-based local refinement (a 6-tuple).
    """
    graph = W.squeeze(0) if W.shape[0] == 1 else W

    # Spectral embedding -> per-node class prediction.
    raw_pred = lr_model.predict(extract_spectral_features(graph, k))

    # Metrics, plus the prediction relabelled by compute_acc_ari_nmi
    # (used as the starting point for refinement).
    acc_raw, aligned_pred, ari_raw, nmi_raw = compute_acc_ari_nmi(raw_pred, true_labels, k)

    # Neighbour-based local refinement, then score again.
    refined_pred = local_refinement_by_neighbors(graph, aligned_pred, k)
    acc_ref, _, ari_ref, nmi_ref = compute_acc_ari_nmi(refined_pred, true_labels, k)

    return acc_raw, ari_raw, nmi_raw, acc_ref, ari_ref, nmi_ref

def get_available_device():
    """Return the first CUDA device that can actually allocate a tensor, else the CPU.

    Each visible GPU is made current in turn and probed with a one-element
    allocation. A device whose probe raises ``RuntimeError`` (e.g. out of
    memory) is skipped.
    """
    for ordinal in range(torch.cuda.device_count()):
        try:
            torch.cuda.set_device(ordinal)
            torch.zeros(1).cuda()
        except RuntimeError:
            continue
        else:
            return torch.device(f"cuda:{ordinal}")
    return torch.device("cpu")


# Module-level device, chosen at import time (first usable CUDA device, else CPU).
# Used by the training branch in __main__. test_single_first_period picks
# its own local `device` instead.
device = get_available_device()

def run_multi_refinement(
    gnn_model,
    W_np,
    init_labels,
    true_labels,
    args,
    device,
    num_iters: int = 1
):
    """Iteratively refine a label assignment with the second-period GNN.

    Each round:
    1. Builds refinement inputs from ``W_np`` and the current labels via
       ``get_gnn_inputs_local_refinement`` (using ``args.J_second`` and
       ``args.n_classes``).
    2. Runs ``gnn_model`` on them.
    3. Scores the output against ``true_labels``.
    4. Uses the matched prediction as the next round's labels.

    Returns a list with one dict per round, with keys ``"iter"`` (1-based),
    ``"pred_label"``, ``"pred"`` (raw model output), ``"acc"``, ``"ari"`` and
    ``"nmi"``. ``num_iters=0`` returns an empty list.
    """
    labels_in = init_labels
    records = []

    for round_idx in range(1, num_iters + 1):
        # Build inputs from the current labels.
        graph_ops, node_feats = get_gnn_inputs_local_refinement(
            W_np, args.J_second, labels_in, args.n_classes
        )
        graph_ops = graph_ops.to(device)
        node_feats = node_feats.to(device)

        # Forward pass.
        scores = gnn_model(graph_ops.type(torch.float32), node_feats.type(torch.float32))

        # Metrics.
        acc, matched, ari, nmi = gnn_compute_acc_ari_nmi_multiclass(
            scores, true_labels, args.n_classes
        )

        # NOTE(review): `matched` is the second return of
        # gnn_compute_acc_ari_nmi_multiclass (called best_matched_pred
        # elsewhere). If it is the prediction relabelled to best match
        # true_labels, the next round's input depends on the ground truth.
        # Confirm this is intended.
        labels_in = matched

        records.append(dict(
            iter=round_idx,
            pred_label=matched,
            pred=scores,
            acc=acc,
            ari=ari,
            nmi=nmi,
        ))

    return records

# ------------------------------------------------------------
# 封装：一次性跑完 total_iters 次 local refinement
# ------------------------------------------------------------
def run_refinement_chain(
    gnn_second_period,
    W_np,
    init_labels,
    true_labels,
    args,
    device,
    total_iters: int = 5,
    verbose: bool = True,
):
    """
    Run ``total_iters`` refinement rounds through run_multi_refinement in one call.

    Returns a dict:
        "hist"       -> list of per-round records (see run_multi_refinement)
        "first_iter" -> the first round's record, or None if no rounds ran
        "final"      -> the last round's record, or None if no rounds ran
    When ``verbose`` is true, one metrics line is printed per round.
    """
    rounds = run_multi_refinement(
        gnn_second_period, W_np, init_labels, true_labels, args, device,
        num_iters=total_iters,
    )

    if verbose:
        for step, rec in enumerate(rounds, start=1):
            print(f"[Iter {step}] acc={rec['acc']:.4f}, ari={rec['ari']:.4f}, nmi={rec['nmi']:.4f}")

    first = rounds[0] if rounds else None
    last = rounds[-1] if rounds else None
    return {"hist": rounds, "first_iter": first, "final": last}


# ------------------------------------------------------------
# 替换版：test_single_first_period
#   - 使用 run_refinement_chain 一次性完成多步 refinement
#   - 兼容你原先的变量名与日志显示
# ------------------------------------------------------------
def test_single_first_period(gnn_first_period, gnn_second_period, A, labels, n_classes, args, iter):
    """
    Evaluate both GNN stages and the spectral baselines on one graph, and dump per-node features.

    Steps:
      1. Remap ``labels`` to 0..K-1.
      2. Compute spectral-clustering baselines (``spectral_clustering_adj``).
      3. Run the first-period GNN, then one round of ``run_refinement_chain``
         with the second-period GNN.
      4. Compute the top-``n_classes`` adjacency eigenvectors and the
         column-normalised penultimate-layer features of both GNNs.
      5. Compute loss/acc/ARI/NMI per stage, plus neighbour-based local refinement.
      6. Write an Excel sheet under
         ``penultimate_GNN_Feature/penultimate_GNN_Feature_nclasses_{n_classes}/``.

    :param gnn_first_period: first-stage GNN. Moved to the evaluation device in place.
    :param gnn_second_period: refinement GNN. Moved to the evaluation device in place.
    :param A: (N, N) adjacency, dense array-like or scipy sparse matrix.
    :param labels: (N,) or (N, 1) labels with arbitrary values (e.g. {1,2} or {3,5,7}).
    :param n_classes: number of communities.
    :param args: parsed CLI args (reads J, num_layers; the refinement reads J_second, n_classes).
    :param iter: index used in the log line and the output file / sheet names.
    :return: ``(first-stage loss as float, {method: {"acc", "ari", "nmi"}})``.
    :raises ValueError: if ``A`` is not a square 2-D matrix.
    """
    start = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Callers may load the models with map_location='cpu'. Put them on the
    # same device as the inputs built below; otherwise the forward passes fail
    # with a device mismatch on GPU machines.
    gnn_first_period.to(device)
    gnn_second_period.to(device)

    # === 0) Uniform input shapes ===
    labels = np.asarray(labels).squeeze()  # -> (N,)
    if issparse(A):
        # np.asarray on a scipy sparse matrix gives a 0-d object array, not (N, N).
        A = A.toarray()
    A_np = np.asarray(A)  # -> (N, N)
    if A_np.ndim != 2 or A_np.shape[0] != A_np.shape[1]:
        raise ValueError(f"A shape bad: {A_np.shape}")
    N = A_np.shape[0]

    # === 1) Map labels to 0..K-1 (sorted order of the distinct values) ===
    _, labels_mapped = np.unique(labels, return_inverse=True)
    labels_mapped = labels_mapped.reshape(-1).astype(np.int64)
    true_labels = torch.tensor(labels_mapped, dtype=torch.long, device=device).unsqueeze(0)  # (1, N)

    # === 2) Spectral baselines (numpy 2-D adjacency) ===
    res_spectral = spectral_clustering_adj(
        A_np, n_classes, true_labels,  # pass true_labels.squeeze(0).cpu().numpy() if 1-D labels are needed
        normalized=True,
        run_all=True,
        random_state=0
    )

    # === 3) GNN inputs ===
    W_np = np.expand_dims(A_np, 0)  # (1, N, N)
    WW, x = get_gnn_inputs(W_np, args.J)
    WW, x = WW.to(device), x.to(device)

    # NOTE(review): both models stay in train() mode for evaluation, so any
    # BatchNorm layers use this single graph's batch statistics. Kept as-is;
    # confirm this is intended (cf. the --freeze_bn / eval_vs_train option).
    gnn_first_period.train()
    gnn_second_period.train()

    with torch.no_grad():
        # --- Stage 1 ---
        pred_single_first = gnn_first_period(WW.type(torch.float32), x.type(torch.float32))
        start_x_label = from_scores_to_labels_multiclass_batch(
            pred_single_first.detach().cpu().numpy()
        )
        print("Multi-step GNN_second_period开始")
        # === Multi-step refinement (single call) ===
        ref = run_refinement_chain(
            gnn_second_period=gnn_second_period,
            W_np=W_np,
            init_labels=start_x_label,
            true_labels=true_labels,
            args=args,
            device=device,
            total_iters=1,
            verbose=True
        )

        # total_iters >= 1, so both the first and the final record exist.
        # The first refinement round is reported as the "second" GNN stage.
        second_pred = ref["first_iter"]["pred"]

        final_acc = ref["final"]["acc"]
        final_ari = ref["final"]["ari"]
        final_nmi = ref["final"]["nmi"]
        print("[Final] "
              f"acc={final_acc:.4f}, ari={final_ari:.4f}, nmi={final_nmi:.4f}")

    # === 4) Adjacency eigenvectors (numpy): top n_classes by eigenvalue ===
    W_for_eig = (A_np + A_np.T) / 2.0  # symmetrise so eigh applies
    eigvals_W, eigvecs_W = np.linalg.eigh(W_for_eig)
    idx = np.argsort(eigvals_W)[-n_classes:][::-1]
    adjacency_eigvecs = eigvecs_W[:, idx]
    adjacency_eigvecs /= (np.linalg.norm(adjacency_eigvecs, axis=0, keepdims=True) + 1e-12)

    # === 5) Penultimate-layer GNN features, column-normalised ===
    # Normalise out of place: .numpy() on a CPU tensor shares memory, so `/=`
    # would write into whatever tensor get_penultimate_output() returned.
    pen_first = gnn_first_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)  # (N, D)
    penultimate_features = pen_first / (np.linalg.norm(pen_first, axis=0, keepdims=True) + 1e-12)

    pen_second = gnn_second_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)  # (N, D)
    penultimate_features_second = pen_second / (np.linalg.norm(pen_second, axis=0, keepdims=True) + 1e-12)

    # === 6) GNN loss and metrics ===
    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    acc_gnn_first, gnn_pred_label_first, ari_gnn_first, nmi_gnn_first = \
        gnn_compute_acc_ari_nmi_multiclass(pred_single_first, true_labels, n_classes)

    acc_gnn_second, gnn_pred_label_second, ari_gnn_second, nmi_gnn_second = \
        gnn_compute_acc_ari_nmi_multiclass(second_pred, true_labels, n_classes)

    # Neighbour-based local refinement of the first-stage labels (numpy A).
    gnn_refined = local_refinement_by_neighbors(A_np, gnn_pred_label_first, n_classes)
    acc_gnn_refined, gnn_refined_label, ari_gnn_refined, nmi_gnn_refined = \
        compute_acc_ari_nmi(gnn_refined, true_labels, n_classes)
    print("Multi-step GNN_local_refinement begining")
    # One round of multi-step neighbour refinement, starting from the stage-1
    # labels. It runs with verbose=True for its logged output; the returned
    # history is not used below.
    local_refinement_by_neighbors_multi(
        A=A_np,                     # (N, N) numpy
        init_labels=start_x_label,
        num_classes=n_classes,
        true_labels=true_labels,    # (1, N) torch.LongTensor
        num_iters=1,
        alpha=1e-6,
        random_state=0,
        tol=0,
        verbose=True,
        return_history=True
    )

    elapsed = time.time() - start
    loss_value = float(loss_test_first.detach().cpu().numpy())
    print(f"[Iter {iter}] Loss={loss_value:.4f}, Acc={acc_gnn_first:.4f}, Time={elapsed:.2f}s")

    # === 7) Save per-node results to Excel (one row per node) ===
    data = {
        'True_Label': true_labels.squeeze(0).cpu().numpy(),          # (N,)
        'Pred_Label_First': gnn_pred_label_first.reshape(-1),        # (N,)
        'Pred_Label_Second': gnn_pred_label_second.reshape(-1),
        'Pred_Label_Refined': gnn_refined_label.reshape(-1),
        'Loss_First': [loss_value] * N,
        'Acc_First': [float(acc_gnn_first)] * N,
    }

    for i in range(penultimate_features.shape[1]):
        data[f'pen_GNN_Feature_first{i + 1}'] = penultimate_features[:, i]

    for i in range(penultimate_features_second.shape[1]):
        data[f'pen_GNN_Feature_second{i + 1}'] = penultimate_features_second[:, i]

    for i in range(adjacency_eigvecs.shape[1]):
        data[f'Adj_EigVecs_Top{i + 1}'] = adjacency_eigvecs[:, i]

    df = pd.DataFrame(data)
    root_folder = "penultimate_GNN_Feature"
    subfolder_name = f"penultimate_GNN_Feature_nclasses_{n_classes}"
    output_filename = f"first_gnn_fromA_iter{iter}_j={args.J}_nlyr={args.num_layers}.xlsx"
    output_path = os.path.join(root_folder, subfolder_name, output_filename)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    df.to_excel(output_path, sheet_name=f'Iteration_{iter}', index=False)

    # === 8) Collect metrics ===
    acc_sc_n, ari_sc_n, nmi_sc_n = res_spectral['normalized']['sc']
    acc_rf_n, ari_rf_n, nmi_rf_n = res_spectral['normalized']['refined']
    acc_sc_u, ari_sc_u, nmi_sc_u = res_spectral['unnormalized']['sc']
    acc_rf_u, ari_rf_u, nmi_rf_u = res_spectral['unnormalized']['refined']
    acc_sc_a, ari_sc_a, nmi_sc_a = res_spectral['adjacency']['sc']
    acc_rf_a, ari_rf_a, nmi_rf_a = res_spectral['adjacency']['refined']

    metrics = {
        "gnn_first": {"acc": float(acc_gnn_first), "ari": float(ari_gnn_first), "nmi": float(nmi_gnn_first)},
        "gnn_second": {"acc": float(acc_gnn_second), "ari": float(ari_gnn_second), "nmi": float(nmi_gnn_second)},
        "gnn_refined": {"acc": float(acc_gnn_refined), "ari": float(ari_gnn_refined), "nmi": float(nmi_gnn_refined)},
        "spectral_normalized": {"acc": float(acc_sc_n), "ari": float(ari_sc_n), "nmi": float(nmi_sc_n)},
        "spectral_normalized_refined": {"acc": float(acc_rf_n), "ari": float(ari_rf_n), "nmi": float(nmi_rf_n)},
        "spectral_unnormalized": {"acc": float(acc_sc_u), "ari": float(ari_sc_u), "nmi": float(nmi_sc_u)},
        "spectral_unnormalized_refined": {"acc": float(acc_rf_u), "ari": float(ari_rf_u), "nmi": float(nmi_rf_u)},
        "spectral_adjacency": {"acc": float(acc_sc_a), "ari": float(ari_sc_a), "nmi": float(nmi_sc_a)},
        "spectral_adjacency_refined": {"acc": float(acc_rf_a), "ari": float(ari_rf_a), "nmi": float(nmi_rf_a)},
    }

    return loss_value, metrics


if __name__ == '__main__':
    try:
        setup_logger("run_gnn")
        gen = Generator()
        gen.N_train = args.N_train
        gen.N_test = args.N_test
        gen.N_val = args.N_val

        gen.edge_density = args.edge_density
        gen.p_SBM = args.p_SBM
        gen.q_SBM = args.q_SBM

        gen.random_noise = args.random_noise
        gen.noise = args.noise
        gen.noise_model = args.noise_model
        gen.generative_model = args.generative_model
        gen.n_classes = args.n_classes

        gen.num_examples_train = args.num_examples_train
        gen.num_examples_test = args.num_examples_test
        gen.num_examples_val = args.num_examples_val

        # 1. Root model directory.
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)

        # 2. Sub-directory per n_classes.
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        # 3. Model file paths.
        filename_first = (
            f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        )
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = (
            f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        )

        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            ################################################################################################################
            # Training phase: build the data loaders.
            gen.prepare_data()

            # 1. Number of DataLoader worker processes.
            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            # 2. Datasets.
            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            # 3. DataLoaders.
            train_loader = DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,  # validation order does not need shuffling
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            test_loader = DataLoader(
                test_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            print(
                f"[Phase 1] Train the GNN：n_classes={args.n_classes}, layers={args.num_layers}, J = {args.J}")
            # Build and train the first-period model.
            if args.generative_model == 'SBM_multiclass':
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3,
                                                  n_classes=args.n_classes)
            else:
                # Without this, gnn_first_period would be unbound below (NameError).
                raise ValueError(
                    f"Unsupported generative_model for training: {args.generative_model!r}")
            if torch.cuda.is_available():
                gnn_first_period = gnn_first_period.to(device)

            # NOTE(review): save_path is the bare filename (relative to the cwd),
            # not path_first. Confirm where the early-stopping checkpoint
            # should be written.
            loss_list, acc_list = train_first_period_with_early_stopping(
                gnn_first_period, train_loader, val_loader, args.n_classes, args,
                epochs=100, patience=10, save_path = filename_first)

            print(f'Saving first-period GNN to {path_first}')
            torch.save(gnn_first_period.cpu(), path_first)

            if torch.cuda.is_available():
                gnn_first_period = gnn_first_period.to('cuda')

        print("[Test phase] starts...")
        print("Training logistic regression models...")
        # lr_model = train_logistic_regression(gen, args.n_classes)

        def _load_full_model(path):
            """Load a whole pickled nn.Module (saved with torch.save(model)) onto the CPU."""
            # torch>=2.6 defaults to weights_only=True, which cannot unpickle a
            # full nn.Module. These checkpoints are produced by this script, so
            # full unpickling is allowed; never point it at untrusted files.
            try:
                return torch.load(path, map_location=torch.device('cpu'), weights_only=False)
            except TypeError:
                # torch<1.13 has no weights_only argument.
                return torch.load(path, map_location=torch.device('cpu'))

        gnn_first_period = _load_full_model(path_first)

        gnn_second_period = _load_full_model(path_second)

        print("[Test phase] starts...")

        mat = loadmat(r"politicalblog_CC.mat")

        A_real = mat["A"]
        labels_real = mat["label"].squeeze()

        # loadmat returns MATLAB sparse matrices as scipy sparse matrices;
        # densify before converting to a float array.
        A = np.asarray(A_real.toarray() if issparse(A_real) else A_real, dtype=float)
        loss, metrics = test_single_first_period(
            gnn_first_period, gnn_second_period, A, labels_real, n_classes=args.n_classes, args=args, iter=0
        )

        df_metrics = pd.DataFrame(metrics).T
        print(df_metrics)

    except Exception:

        import traceback

        # stderr is redirected to the log file, so the traceback lands there.
        # Exit non-zero so schedulers and shells see the failure instead of a
        # silent success.
        traceback.print_exc()
        sys.exit(1)