import argparse
import multiprocessing
import os
import random
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from scipy.sparse.csgraph import laplacian
from sklearn.linear_model import LogisticRegression
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from controlsnr import find_a_given_snr
from data_generator import Generator, simple_collate_fn
from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from models import GNN_multiclass
from spectral_clustering import spectral_clustering_adj, to_one_hot
from train_first_period import train_first_period_with_early_stopping
from scipy.io import loadmat
from losses import from_scores_to_labels_multiclass_batch
from load_local_refinement import get_gnn_inputs_local_refinement

def setup_logger(prefix="main_gnn"):
    """Redirect ``sys.stdout`` and ``sys.stderr`` to a new per-process log file.

    The file is created in the current working directory and is named
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``.  It is opened line-buffered so
    each completed line reaches disk promptly, and it is deliberately left
    open: it becomes the process-wide output stream for the rest of the run.

    Args:
        prefix: Leading part of the log file name.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    log_filename = f"{prefix}_{timestamp}_{pid}.log"
    # Explicit UTF-8: this script prints non-ASCII (Chinese) messages, which a
    # locale-dependent default encoding may not be able to encode.
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")


parser = argparse.ArgumentParser()


def _add_valued_option(flag, value_type, default):
    """Register ``flag`` as an option whose value may be omitted.

    Every such option uses ``nargs='?'`` with ``const=1``: ``--flag VALUE``
    stores ``value_type(VALUE)``, a bare ``--flag`` stores the constant 1,
    and an absent flag stores ``default``.
    """
    parser.add_argument(flag, nargs='?', const=1, type=value_type,
                        default=default)


###############################################################################
#                             General Settings                                #
#            (all parameters are declared up front for later use)             #
###############################################################################

_add_valued_option('--num_examples_train', int, 6000)
_add_valued_option('--num_examples_test', int, 100)
_add_valued_option('--num_examples_val', int, 1000)
_add_valued_option('--edge_density', float, 0.2)
_add_valued_option('--p_SBM', float, 0.2)
_add_valued_option('--q_SBM', float, 0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
_add_valued_option('--noise', float, 2)
_add_valued_option('--noise_model', int, 2)
# The generative model defaults to 'SBM_multiclass'; an earlier
# 'ErdosRenyi' default is no longer active.
_add_valued_option('--generative_model', str, 'SBM_multiclass')
_add_valued_option('--batch_size', int, 16)
_add_valued_option('--mode', str, 'test')

# Directory containing this script; used as the default for --path_gnn.
default_path = os.path.dirname(os.path.abspath(__file__))
_add_valued_option('--mode_isbalanced', str, 'imbalanced')
_add_valued_option('--path_gnn', str, default_path)
_add_valued_option('--path_local_refinement', str, '')

_add_valued_option('--filename_existing_gnn', str, '')
_add_valued_option('--filename_existing_gnn_local_refinement', str, '')

_add_valued_option('--print_freq', int, 1)
_add_valued_option('--test_freq', int, 500)
_add_valued_option('--save_freq', int, 2000)
_add_valued_option('--clip_grad_norm', float, 10.0)
# NOTE(review): the default for eval_vs_train is already True, so passing
# --freeze_bn (store_true) does not change the parsed value -- confirm intent.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

_add_valued_option('--num_features', int, 16)
_add_valued_option('--num_layers', int, 30)
_add_valued_option('--n_classes', int, 2)

###############################################################################
#                              GNN second period                              #
###############################################################################
_add_valued_option('--num_features_second', int, 16)
_add_valued_option('--num_layers_second', int, 10)
_add_valued_option('--J_second', int, 2)

_add_valued_option('--J', int, 2)
_add_valued_option('--N_train', int, 1000)
_add_valued_option('--N_test', int, 1000)
_add_valued_option('--N_val', int, 1000)

_add_valued_option('--lr', float, 4e-3)

# Parsed once at import time; the rest of the module reads `args` directly.
args = parser.parse_args()

# Legacy tensor-type aliases: the CUDA variants when CUDA is available,
# the CPU variants otherwise.  No RNG seed is set here.
_cuda_ok = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if _cuda_ok else torch.FloatTensor
dtype_l = torch.cuda.LongTensor if _cuda_ok else torch.LongTensor

# Module-level training helpers derived from the parsed CLI arguments.
batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Fixed-width, left-aligned format templates for tabular log output.
# templateN pairs: odd numbers are header rows, even numbers are value rows.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'
template3 = '{:<10} {:<10} {:<10} '
template4 = '{:<10} {:<10.5f} {:<10.5f} \n'

template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row    = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'



class SBMDataset(torch.utils.data.Dataset):
    """Map-style dataset of graphs stored one per ``.npz`` archive.

    Each archive is read for the keys ``adj_data``, ``adj_indices``,
    ``adj_indptr`` and ``adj_shape`` (the CSR components of the adjacency
    matrix) and ``labels``.
    """

    def __init__(self, npz_file_list):
        # Only the paths are stored; archives are read lazily per item.
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load graph ``idx`` and return a dict with ``adj``, ``labels``, ``num_nodes``."""
        # np.load on an .npz returns an NpzFile that holds the archive open
        # until closed; the context manager releases the handle so long runs
        # (and DataLoader workers) do not accumulate open files.  All arrays
        # are materialised inside the block, so they stay valid afterwards.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}


# Imports used by SBMDataset above and by the spectral helpers below
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors
import numpy as np
from scipy.sparse import issparse
from scipy.sparse.linalg import eigsh as sparse_eigsh, eigsh
from numpy.linalg import eigh as dense_eigh
from scipy.linalg import eigh as scipy_eigh


# def extract_spectral_features(
#     adj,
#     k,
#     *,
#     row_norm=True,
#     add_relu=False,          # 是否把 ReLU(U) 一并拼上
#     degree_mode="norm",      # 'raw' | 'norm' | 'log_norm' | 'zscore' | 'sqrt_norm'
#     degree_weight=1.0        # λ：拼接时度特征的权重
# ):
#     """
#     返回特征:
#         若 add_relu=False:  [U, λ·d]              -> 形状 (N, k+1)
#         若 add_relu=True :  [U, ReLU(U), λ·d]     -> 形状 (N, 2k+1)
#     """
#     # ---------- 谱嵌入 ----------
#     if issparse(adj):
#         # 稀疏分支：保留稀疏，直接对稀疏拉普拉斯做 eigsh
#         A = adj.tocsr().astype(np.float64)
#         L = laplacian(A, normed=True).astype(np.float64)
#
#         n = L.shape[0]
#         k_eff = min(k, max(1, n - 2))  # ARPACK 要求 k < n，留余量更稳
#         # 取最靠近 0 的 k 个特征向量（对应最小特征值）
#         vals, U = eigsh(L, k=k_eff, sigma=0.0, which='LM')
#         # 按特征值升序
#         U = U[:, np.argsort(vals)]
#         N = n
#     else:
#         # 稠密分支：用 SciPy 的 eigh + subset_by_index
#         A = np.asarray(adj, dtype=np.float64)
#         N = A.shape[0]
#         L = laplacian(A, normed=True).astype(np.float64)
#         # 仅求前 k 个最小特征向量
#         _, U = scipy_eigh(L, subset_by_index=(0, k - 1), check_finite=False)
#
#     # 列符号对齐：每列均值为非负
#     col_mean = U.mean(axis=0, keepdims=True)
#     sign = np.where(col_mean >= 0, 1.0, -1.0)
#     U = U * sign
#
#     # 行归一化
#     if row_norm:
#         denom = np.linalg.norm(U, axis=1, keepdims=True) + 1e-12
#         U = U / denom
#
#     # 可选 ReLU
#     base_feat = np.hstack([U, np.maximum(U, 0.0)]) if add_relu else U  # (N, 2k) or (N, k)
#
#     # ---------- 度特征 ----------
#     if issparse(adj):
#         # 稀疏度：sum(axis=1) 返回 (N,1) 矩阵
#         d = np.asarray(adj.sum(axis=1)).ravel()
#     else:
#         d = A.sum(axis=1)
#
#     if degree_mode == "raw":
#         d_feat = d
#     elif degree_mode == "norm":
#         m = d.max()
#         d_feat = d / m if m > 0 else d
#     elif degree_mode == "log_norm":
#         d_log = np.log1p(d)
#         m = d_log.max()
#         d_feat = d_log / m if m > 0 else d_log
#     elif degree_mode == "sqrt_norm":
#         d_sqrt = np.sqrt(d)
#         m = d_sqrt.max()
#         d_feat = d_sqrt / m if m > 0 else d_sqrt
#     elif degree_mode == "zscore":
#         mu, std = d.mean(), d.std()
#         d_feat = (d - mu) / (std + 1e-12)
#     else:
#         raise ValueError("degree_mode 必须是 'raw' | 'norm' | 'log_norm' | 'zscore' | 'sqrt_norm' 之一")
#
#     d_feat = degree_weight * d_feat.reshape(N, 1)
#
#     # ---------- 拼接并返回 ----------
#     features = np.hstack([base_feat, d_feat])  # (N, k+1) 或 (N, 2k+1)
#     return features

def extract_spectral_features(
    adj,
    k: int,
    *,
    row_norm: bool = True,
    use_abs: bool = True,
):
    """Return leading eigenvectors of an adjacency matrix as node features.

    Args:
        adj: Square adjacency matrix, SciPy sparse or array-like.  Symmetric
            eigen-solvers are used, so it is expected to be real symmetric.
        k: Number of eigenvectors requested; must be at least 1.
        row_norm: If True, L2-normalise every row of the result.
        use_abs: If True, select the eigenvalues of largest magnitude;
            otherwise select the largest algebraic eigenvalues.

    Returns:
        ``np.ndarray`` of shape ``(N, min(k, N))``.  Columns are ordered by
        decreasing selection key, and each column is sign-flipped so that its
        mean is non-negative (removes the arbitrary sign of eigenvectors).

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    # k == 0 previously slipped through: ``[-0:]`` is a full slice, so the
    # dense path returned every eigenvector instead of none.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # ARPACK needs k < N; the same safety margin as before (k <= N - 2) is
    # kept.  Sparse matrices too small for it fall through to the dense
    # solver instead of silently returning fewer than k columns.
    if issparse(adj) and k <= adj.shape[0] - 2:
        # Cast to float64 first: the sparse solver errors on integer dtypes.
        A = adj.astype(np.float64)
        w, U = sparse_eigsh(A, k=k, which="LM" if use_abs else "LA")
        # Re-order columns explicitly by the selection criterion.
        order = np.argsort(np.abs(w) if use_abs else w)[::-1]
        U = U[:, order]
    else:
        if issparse(adj):
            A = adj.toarray().astype(np.float64)
        else:
            A = np.asarray(adj, dtype=np.float64)
        # Full spectrum (ascending eigenvalues), then pick the top k by hand.
        w_all, V_all = dense_eigh(A)
        if use_abs:
            idx = np.argsort(np.abs(w_all))[-k:]
            idx = idx[np.argsort(np.abs(w_all[idx]))[::-1]]
        else:
            idx = np.argsort(w_all)[-k:][::-1]
        U = V_all[:, idx]

    # Column sign alignment so repeated runs do not flip whole columns.
    col_mean = U.mean(axis=0, keepdims=True)
    U = U * np.where(col_mean >= 0, 1.0, -1.0)

    # Optional row normalisation; the epsilon guards all-zero rows.
    if row_norm:
        U = U / (np.linalg.norm(U, axis=1, keepdims=True) + 1e-12)

    return U


import joblib

# Default the OpenMP / MKL thread counts to 1 unless the environment already
# provides a value -- presumably to avoid thread oversubscription when
# features are extracted in parallel worker processes (confirm).
for _thread_env_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ.setdefault(_thread_env_var, "1")

def _load_and_feat(path, k):
    """Load one graph archive and compute its spectral node features.

    Args:
        path: Path to an ``.npz`` file holding the CSR components
            (``adj_data``, ``adj_indices``, ``adj_indptr``, ``adj_shape``)
            and ``labels``.
        k: Number of eigenvectors passed to ``extract_spectral_features``.

    Returns:
        ``(feats, labels)`` where ``labels`` is a flat integer array.
    """
    # Close the NpzFile as soon as the arrays are read: this function runs
    # once per training file inside worker processes, and an unclosed
    # archive keeps its file handle open.
    with np.load(path) as data:
        adj = csr_matrix(
            (data['adj_data'], data['adj_indices'], data['adj_indptr']),
            shape=tuple(data['adj_shape']),
        )
        labels = np.asarray(data['labels']).ravel().astype(int)
    feats = extract_spectral_features(adj, k)  # sparse input -> sparse eigensolver path
    return feats, labels

def train_logistic_regression(gen, k, save_path="./models/node_info_added_lr_model.pkl",
                              n_jobs=max(1, (os.cpu_count() or 1) // 2)):
    """Train (or load a cached) logistic-regression node classifier.

    If ``save_path`` already exists the stored model is returned without
    training.  Otherwise spectral features are extracted in parallel from
    every file in ``gen.data_train``, a multinomial logistic regression is
    fitted on the stacked node features, and the model is saved.

    Args:
        gen: Object exposing ``data_train``, an iterable of ``.npz`` paths.
        k: Number of spectral features per node.
        save_path: Where the fitted model is cached.
        n_jobs: Worker processes for feature extraction.  Defaults to half
            the CPU count, but never less than 1 (``os.cpu_count()`` may
            return ``None`` or 1, which previously produced an invalid 0 or
            a ``TypeError`` at definition time).

    Returns:
        The fitted (or loaded) ``LogisticRegression`` model.

    Raises:
        ValueError: If ``gen.data_train`` yields no files.
    """
    # Reuse the cached model when present.
    # NOTE: joblib.load unpickles the file -- only point this at trusted files.
    if os.path.exists(save_path):
        print(f"检测到已存在的模型文件 {save_path}，直接加载。")
        return joblib.load(save_path)

    # Extract features in parallel, one task per training graph.
    results = Parallel(n_jobs=n_jobs, prefer="processes")(
        delayed(_load_and_feat)(fp, k) for fp in gen.data_train
    )
    if not results:
        raise ValueError("gen.data_train is empty; no graphs to train the logistic regression on")
    all_features, all_labels = zip(*results)

    X_train = np.vstack(all_features).astype(np.float32, copy=False)
    y_train = np.concatenate(all_labels)

    # NOTE(review): `multi_class` appears to be deprecated in recent
    # scikit-learn releases -- confirm against the pinned version.
    lr_model = LogisticRegression(
        multi_class='multinomial',
        solver='saga',  # suited to large datasets
        max_iter=1000,
        random_state=42
    )

    lr_model.fit(X_train, y_train)

    # Persist the model; a bare file name has no directory to create.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    joblib.dump(lr_model, save_path)
    print(f"模型已保存到 {save_path}")

    return lr_model


def evaluate_using_logitic(W, true_labels, lr_model, k):
    """Score the logistic-regression baseline on one graph.

    The graph's spectral features are classified with ``lr_model``; the
    prediction is scored, refined with ``local_refinement_by_neighbors``,
    and scored again.

    :param W: Adjacency of the graph under test; a leading axis of length 1
        is squeezed away.
    :param true_labels: Ground-truth labels, forwarded to ``compute_acc_ari_nmi``.
    :param lr_model: Fitted logistic-regression model (must provide ``predict``).
    :param k: Number of communities; also the number of spectral features.
    :return: ``(acc, ari, nmi, acc_refined, ari_refined, nmi_refined)``.
    """
    graph = W.squeeze(0) if W.shape[0] == 1 else W

    # Spectral embedding -> class prediction.
    raw_pred = lr_model.predict(extract_spectral_features(graph, k))

    # compute_acc_ari_nmi yields (best_acc, best_pred, ari, nmi).
    acc, matched_pred, ari, nmi = compute_acc_ari_nmi(raw_pred, true_labels, k)

    # Second score after neighbour-based local refinement of the matched labels.
    refined_pred = local_refinement_by_neighbors(graph, matched_pred, k)
    acc_ref, _, ari_ref, nmi_ref = compute_acc_ari_nmi(refined_pred, true_labels, k)

    return acc, ari, nmi, acc_ref, ari_ref, nmi_ref

def get_available_device():
    """Return the first CUDA device that accepts an allocation, else the CPU.

    Each visible GPU is made the current CUDA device in turn and probed with
    a one-element tensor; a ``RuntimeError`` during the probe moves on to the
    next GPU.  The selected GPU stays the current CUDA device.
    """
    gpu_total = torch.cuda.device_count()
    gpu_index = 0
    while gpu_index < gpu_total:
        try:
            # Probe: select the device and try a tiny allocation on it.
            torch.cuda.set_device(gpu_index)
            torch.zeros(1).cuda()
        except RuntimeError:
            gpu_index += 1
        else:
            return torch.device(f"cuda:{gpu_index}")
    return torch.device("cpu")


# Device chosen once at import time and shared by the rest of the module.
device = get_available_device()

# def test_single_first_period(
#     gnn_first_period, lr_model, A, labels, n_classes, args, iter
# ):
#     """
#     使用已有的邻接矩阵 A 和标签 labels 来测试 GNN。
#     A: numpy.ndarray (N,N)
#     labels: numpy.ndarray (N,) 或 (N,1)
#     """
#     start = time.time()
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     gnn_first_period.eval()
#
#     # === 1. 转成 torch 张量 ===
#     if labels.ndim > 1:
#         labels = labels.squeeze()
#
#     W = torch.tensor(A, dtype=torch.float32, device=device).unsqueeze(0)  # (1,N,N)
#     # W = A.squeeze(0).cpu().numpy() if isinstance(A, torch.Tensor) else A
#
#     labels = np.array(labels) - 1  # 把 1,2 → 0,1
#     true_labels = torch.tensor(labels, dtype=torch.long, device=device).unsqueeze(0)  # (1,N)
#     # === 2. 跑谱聚类 baseline ===
#
#     res_spectral = spectral_clustering_adj(
#         W, n_classes, true_labels,
#         normalized=True,
#         run_all=True,
#         random_state=0
#     )
#
#     # === 3. 跑 logistic baseline ===
#     acc_logistic, ari_logistic, nmi_logistic, acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined \
#         = evaluate_using_logitic(W, true_labels, lr_model, n_classes)
#
#     # === 4. GNN 输入 ===
#     WW, x = get_gnn_inputs(W, args.J)
#     WW, x = WW.to(device), x.to(device)
#
#     with torch.no_grad():
#         pred_single_first = gnn_first_period(WW.type(torch.float32), x.type(torch.float32))
#
#     # === 5. 邻接矩阵特征向量 ===
#     W_for_eig = (A + A.T) / 2
#     eigvals_W, eigvecs_W = np.linalg.eigh(W_for_eig)
#     adjacency_eigvecs = eigvecs_W[:, np.argsort(eigvals_W)[-n_classes:][::-1]]
#     adjacency_eigvecs /= np.linalg.norm(adjacency_eigvecs, axis=0, keepdims=True)
#
#     # === 6. 提取 GNN penultimate 特征 ===
#     penultimate_features = gnn_first_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)
#     penultimate_features /= np.linalg.norm(penultimate_features, axis=0, keepdims=True)
#
#     # === 7. GNN loss & acc ===
#     loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
#     acc_gnn_first, best_matched_pred, ari_gnn_first, nmi_gnn_first = \
#         gnn_compute_acc_ari_nmi_multiclass(pred_single_first, true_labels, n_classes)
#     gnn_pred_label_first = best_matched_pred
#
#     # local refinement
#     gnn_refined = local_refinement_by_neighbors(A, gnn_pred_label_first, n_classes)
#     acc_gnn_refined, best_matched_pred, ari_gnn_refined, nmi_gnn_refined = \
#         compute_acc_ari_nmi(gnn_refined, true_labels, n_classes)
#
#     elapsed = time.time() - start
#     loss_value = float(loss_test_first.detach().cpu().numpy())
#
#     print(f"[Iter {iter}] Loss={loss_value:.4f}, Acc={acc_gnn_first:.4f}, Time={elapsed:.2f}s")
#
#     # === 8. 保存 Excel ===
#     N = true_labels.shape[1]
#     data = {
#         'True_Label': true_labels.squeeze(0).cpu().numpy(),
#         'Pred_Label_First': gnn_pred_label_first.reshape(-1),
#         'Loss_First': [loss_value] * N,
#         'Acc_First': [float(acc_gnn_first)] * N,
#     }
#
#     # penultimate 特征
#     for i in range(2 * n_classes):
#         data[f'penultimate_GNN_Feature{i + 1}'] = penultimate_features[:, i]
#
#     # 上游 eigvecs_top（来自谱分解函数 spectral_clustering_adj）
#     # eigvecs_top = res_spectral.get("eigvecs_top", np.zeros((N, n_classes)))  # 兼容处理
#     # for i in range(n_classes):
#     #     data[f'eigvecs_top{i + 1}'] = eigvecs_top[:, i]
#     # 邻接矩阵特征向量
#
#     for i in range(n_classes):
#         data[f'Adj_EigVecs_Top{i + 1}'] = adjacency_eigvecs[:, i]
#
#     df = pd.DataFrame(data)
#
#     root_folder = "penultimate_GNN_Feature"
#     subfolder_name = f"penultimate_GNN_Feature_nclasses_{n_classes}"
#     output_filename = (
#         f"first_gnn_fromA_iter{iter}_j={args.J}_nlyr={args.num_layers}.xlsx"
#     )
#     output_path = os.path.join(root_folder, subfolder_name, output_filename)
#     os.makedirs(os.path.dirname(output_path), exist_ok=True)
#
#     if iter < 10:
#         if iter == 0:
#             df.to_excel(output_path, sheet_name=f'Iteration_{iter}', index=False)
#         else:
#             with pd.ExcelWriter(output_path, mode='a', engine='openpyxl') as writer:
#                 df.to_excel(writer, sheet_name=f'Iteration_{iter}', index=False)
#
#     # === 9. 拆三种谱方法的结果 ===
#     acc_sc_n, ari_sc_n, nmi_sc_n = res_spectral['normalized']['sc']
#     acc_rf_n, ari_rf_n, nmi_rf_n = res_spectral['normalized']['refined']
#
#     acc_sc_u, ari_sc_u, nmi_sc_u = res_spectral['unnormalized']['sc']
#     acc_rf_u, ari_rf_u, nmi_rf_u = res_spectral['unnormalized']['refined']
#
#     acc_sc_a, ari_sc_a, nmi_sc_a = res_spectral['adjacency']['sc']
#     acc_rf_a, ari_rf_a, nmi_rf_a = res_spectral['adjacency']['refined']
#
#     metrics = {
#         "gnn": {
#             "acc": float(acc_gnn_first),
#             "ari": float(ari_gnn_first),
#             "nmi": float(nmi_gnn_first),
#         },
#         "gnn_refined": {
#             "acc": float(acc_gnn_refined),
#             "ari": float(ari_gnn_refined),
#             "nmi": float(nmi_gnn_refined),
#         },
#         "spectral_normalized": {"acc": float(acc_sc_n), "ari": float(ari_sc_n), "nmi": float(nmi_sc_n)},
#         "spectral_normalized_refined": {"acc": float(acc_rf_n), "ari": float(ari_rf_n), "nmi": float(nmi_rf_n)},
#         "spectral_unnormalized": {"acc": float(acc_sc_u), "ari": float(ari_sc_u), "nmi": float(nmi_sc_u)},
#         "spectral_unnormalized_refined": {"acc": float(acc_rf_u), "ari": float(ari_rf_u), "nmi": float(nmi_rf_u)},
#         "spectral_adjacency": {"acc": float(acc_sc_a), "ari": float(ari_sc_a), "nmi": float(nmi_sc_a)},
#         "spectral_adjacency_refined": {"acc": float(acc_rf_a), "ari": float(ari_rf_a), "nmi": float(nmi_rf_a)},
#         "logistic": {
#             "acc": float(acc_logistic),
#             "ari": float(ari_logistic),
#             "nmi": float(nmi_logistic),
#         },
#         "logistic_refined": {
#             "acc": float(acc_logistic_refined),
#             "ari": float(ari_logistic_refined),
#             "nmi": float(nmi_logistic_refined),
#         },
#     }
#
#     return loss_value, metrics
def _metric_dict(acc, ari, nmi):
    """Pack one (accuracy, ARI, NMI) triple into a dict of plain floats."""
    return {"acc": float(acc), "ari": float(ari), "nmi": float(nmi)}


def test_single_first_period(gnn_first_period, gnn_second_period, lr_model, A, labels, n_classes, args, iter):
    """Evaluate both GNN stages and the baselines on one given graph.

    Args:
        gnn_first_period: First-stage GNN.  It is moved to the evaluation
            device in place and switched to ``train()`` mode.
        gnn_second_period: Second-stage (local-refinement) GNN; same handling.
        lr_model: Fitted logistic-regression baseline.
        A: (N, N) adjacency matrix, dense array-like or SciPy sparse.
        labels: N ground-truth labels of shape (N,), (N, 1) or (1, N).  The
            values are arbitrary (e.g. {1, 2}); they are remapped to
            0..K-1 following sorted order.
        n_classes: Number of communities.
        args: Parsed CLI namespace (``J``, ``n_classes``, ``num_layers`` are read).
        iter: Iteration index, used in the log line and the Excel file name.

    Returns:
        ``(loss_value, metrics)``: the first-stage loss as a float and a dict
        mapping method name to ``{"acc", "ari", "nmi"}``.

    Raises:
        ValueError: If ``A`` is not square 2-D or ``labels`` does not have N entries.

    Side effects:
        Writes per-node diagnostics to an ``.xlsx`` file under
        ``penultimate_GNN_Feature/`` and prints one summary line.
    """
    start = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === 0) Normalise input shapes ===
    labels = np.asarray(labels).reshape(-1)                       # -> (N,)
    # np.asarray on a SciPy sparse matrix yields a 0-d object array, so
    # sparse input (e.g. straight from loadmat) is densified explicitly.
    A_np = A.toarray() if issparse(A) else np.asarray(A)          # -> (N, N)
    if A_np.ndim != 2 or A_np.shape[0] != A_np.shape[1]:
        raise ValueError(f"A shape bad: {A_np.shape}")
    N = A_np.shape[0]
    if labels.shape[0] != N:
        raise ValueError(f"labels has {labels.shape[0]} entries but A has {N} nodes")

    # === 1) Map arbitrary label values onto 0..K-1 (sorted order) ===
    _, inverse = np.unique(labels, return_inverse=True)
    labels_mapped = inverse.reshape(-1).astype(np.int64)

    true_labels = torch.tensor(labels_mapped, dtype=torch.long, device=device).unsqueeze(0)  # (1, N)

    # === 2) Spectral and logistic baselines (dense numpy adjacency) ===
    res_spectral = spectral_clustering_adj(
        A_np, n_classes, true_labels,
        normalized=True,
        run_all=True,
        random_state=0
    )

    acc_logistic, ari_logistic, nmi_logistic, acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined = \
        evaluate_using_logitic(A_np, true_labels, lr_model, n_classes)

    # === 3) GNN inputs ===
    W_np = np.expand_dims(A_np, 0)  # (1, N, N)
    WW, x = get_gnn_inputs(W_np, args.J)
    WW, x = WW.to(device), x.to(device)

    # The inputs live on `device`; the models may have been loaded on the CPU
    # by the caller, which would raise a device-mismatch error on CUDA hosts.
    gnn_first_period = gnn_first_period.to(device)
    gnn_second_period = gnn_second_period.to(device)

    # NOTE(review): both models are put in train() mode for evaluation
    # (only gradients are disabled below) -- confirm this is intended.
    gnn_first_period.train()
    gnn_second_period.train()

    with torch.no_grad():
        pred_single_first = gnn_first_period(WW.type(torch.float32), x.type(torch.float32))
        # First-stage hard labels seed the second-stage inputs.
        start_x_label = from_scores_to_labels_multiclass_batch(pred_single_first)
        # NOTE(review): args.J (not args.J_second) is used here -- verify.
        WW_second, x_second = get_gnn_inputs_local_refinement(W_np, args.J, start_x_label, args.n_classes)
        WW_second, x_second = WW_second.to(device), x_second.to(device)
        pred_single_second = gnn_second_period(WW_second, x_second)

    # === 4) Top eigenvectors of the symmetrised adjacency (numpy) ===
    W_for_eig = (A_np + A_np.T) / 2.0
    eigvals_W, eigvecs_W = np.linalg.eigh(W_for_eig)
    idx = np.argsort(eigvals_W)[-n_classes:][::-1]
    adjacency_eigvecs = eigvecs_W[:, idx]
    adjacency_eigvecs /= (np.linalg.norm(adjacency_eigvecs, axis=0, keepdims=True) + 1e-12)

    # === 5) Penultimate-layer GNN features, column-normalised ===
    penultimate_features = gnn_first_period.get_penultimate_output().detach().cpu().numpy().squeeze(0)  # (N, D)
    # Out-of-place division: the numpy view may share memory with the
    # tensor held by the model, which an in-place update would overwrite.
    penultimate_features = penultimate_features / (
        np.linalg.norm(penultimate_features, axis=0, keepdims=True) + 1e-12
    )

    # === 6) GNN loss and scores ===
    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    acc_gnn_first, gnn_pred_label_first, ari_gnn_first, nmi_gnn_first = \
        gnn_compute_acc_ari_nmi_multiclass(pred_single_first, true_labels, n_classes)

    acc_gnn_second, _, ari_gnn_second, nmi_gnn_second = \
        gnn_compute_acc_ari_nmi_multiclass(pred_single_second, true_labels, n_classes)

    # Neighbour-based local refinement of the first-stage labels.
    gnn_refined = local_refinement_by_neighbors(A_np, gnn_pred_label_first, n_classes)
    acc_gnn_refined, _, ari_gnn_refined, nmi_gnn_refined = \
        compute_acc_ari_nmi(gnn_refined, true_labels, n_classes)

    elapsed = time.time() - start
    loss_value = float(loss_test_first.detach().cpu().numpy())
    print(f"[Iter {iter}] Loss={loss_value:.4f}, Acc={acc_gnn_first:.4f}, Time={elapsed:.2f}s")

    # === 7) Per-node Excel dump (column count follows the feature widths) ===
    data = {
        'True_Label': true_labels.squeeze(0).cpu().numpy(),          # (N,)
        'Pred_Label_First': gnn_pred_label_first.reshape(-1),        # (N,)
        'Loss_First': [loss_value] * N,
        'Acc_First': [float(acc_gnn_first)] * N,
    }

    for i in range(penultimate_features.shape[1]):
        data[f'penultimate_GNN_Feature{i + 1}'] = penultimate_features[:, i]

    for i in range(adjacency_eigvecs.shape[1]):
        data[f'Adj_EigVecs_Top{i + 1}'] = adjacency_eigvecs[:, i]

    df = pd.DataFrame(data)
    root_folder = "penultimate_GNN_Feature"
    subfolder_name = f"penultimate_GNN_Feature_nclasses_{n_classes}"
    output_filename = f"first_gnn_fromA_iter{iter}_j={args.J}_nlyr={args.num_layers}.xlsx"
    output_path = os.path.join(root_folder, subfolder_name, output_filename)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Each iteration has its own file name, so a plain write is enough.
    df.to_excel(output_path, sheet_name=f'Iteration_{iter}', index=False)

    # === 8) Collect all metrics (insertion order defines the report order) ===
    metrics = {
        "gnn_first": _metric_dict(acc_gnn_first, ari_gnn_first, nmi_gnn_first),
        "gnn_second": _metric_dict(acc_gnn_second, ari_gnn_second, nmi_gnn_second),
        "gnn_refined": _metric_dict(acc_gnn_refined, ari_gnn_refined, nmi_gnn_refined),
    }
    for method in ("normalized", "unnormalized", "adjacency", "svt"):
        metrics[f"spectral_{method}"] = _metric_dict(*res_spectral[method]['sc'])
        metrics[f"spectral_{method}_refined"] = _metric_dict(*res_spectral[method]['refined'])
    metrics["logistic"] = _metric_dict(acc_logistic, ari_logistic, nmi_logistic)
    metrics["logistic_refined"] = _metric_dict(
        acc_logistic_refined, ari_logistic_refined, nmi_logistic_refined
    )

    return loss_value, metrics



if __name__ == '__main__':
    try:
        setup_logger("run_gnn")

        # Copy the data-generation settings from the CLI onto the generator.
        gen = Generator()
        for _setting in (
            "N_train", "N_test", "N_val",
            "edge_density", "p_SBM", "q_SBM",
            "random_noise", "noise", "noise_model",
            "generative_model", "n_classes",
            "num_examples_train", "num_examples_test", "num_examples_val",
        ):
            setattr(gen, _setting, getattr(args, _setting))

        # 1. Root folder for all saved models.
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)

        # 2. Sub-folder keyed by the number of classes.
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        # 3. Model file paths for the two stages.
        filename_first = (
            f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        )
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = (
            f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        )
        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            ################################################################################################################
            # Training phase: build the data and the loaders the trainer needs.
            gen.prepare_data()

            # Number of DataLoader worker processes.
            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            train_loader = DataLoader(
                SBMDataset(gen.data_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            val_loader = DataLoader(
                SBMDataset(gen.data_val),
                batch_size=args.batch_size,
                shuffle=False,  # validation order does not need shuffling
                num_workers=num_workers,
                collate_fn=simple_collate_fn
            )

            print(
                f"[阶段一] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers}, J = {args.J}")

            # Only the multiclass SBM model is constructed here; fail with a
            # clear message instead of a NameError further down.
            if args.generative_model != 'SBM_multiclass':
                raise ValueError(
                    f"Unsupported generative_model for training: {args.generative_model!r}"
                )
            gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3,
                                              n_classes=args.n_classes)
            if torch.cuda.is_available():
                gnn_first_period = gnn_first_period.to(device)

            # NOTE(review): save_path receives the bare file name, not
            # path_first -- confirm where early-stopping checkpoints should go.
            loss_list, acc_list = train_first_period_with_early_stopping(
                gnn_first_period, train_loader, val_loader, args.n_classes, args,
                epochs=100, patience=10, save_path=filename_first)

            print(f'Saving first-period GNN to {path_first}')
            # The model is reloaded from disk below, so it is not moved back
            # to the GPU after saving.
            torch.save(gnn_first_period.cpu(), path_first)

        print("[测试阶段] 开始...")
        print("训练逻辑回归模型...")
        # NOTE(review): gen.prepare_data() only runs in train mode; in test
        # mode this presumably relies on a cached logistic model -- verify.
        lr_model = train_logistic_regression(gen, args.n_classes)

        # NOTE: torch.load unpickles whole modules here -- load trusted files
        # only.  Recent PyTorch releases default to weights_only=True, which
        # rejects fully pickled modules; pass weights_only=False if needed.
        gnn_first_period = torch.load(path_first, map_location=torch.device('cpu'))
        gnn_second_period = torch.load(path_second, map_location=torch.device('cpu'))

        print("[测试阶段] 开始...")

        # 1) Read the real graph and its labels from the .mat file.
        mat = loadmat("politicalblog_CC.mat")
        A_real = mat["A"]
        # loadmat returns SciPy sparse matrices for sparse MATLAB arrays;
        # the evaluation code expects a dense 2-D array.
        if issparse(A_real):
            A_real = A_real.toarray()
        labels_real = mat["label"].squeeze()
        if labels_real.min() == 1:  # shift 1-based labels to 0-based
            labels_real = labels_real - 1

        # 2) Evaluate all methods on the real graph.
        loss, metrics = test_single_first_period(
            gnn_first_period, gnn_second_period, lr_model, A_real, labels_real,
            n_classes=args.n_classes, args=args, iter=0
        )

        df_metrics = pd.DataFrame(metrics).T
        print(df_metrics)

    except Exception:
        import traceback

        # stderr points at the log file once setup_logger has run.
        traceback.print_exc()
        # Report the failure to the caller / scheduler instead of exiting 0.
        sys.exit(1)