import argparse
import multiprocessing
import os
import random
import sys
import time
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
import time, numpy as np
from torch import Tensor
from controlsnr import find_a_given_snr
from data_generator_sbm import Generator, simple_collate_fn
from load import get_gnn_inputs
from load_local_refinement import get_gnn_inputs_local_refinement
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from losses import from_scores_to_labels_multiclass_batch
from models import GNN_multiclass, GNN_multiclass_second_period
from spectral_clustering import spectral_clustering_adj
from train_first_period import train_first_period_with_early_stopping
from train_second_period import train_second_period_with_early_stopping
from controlsnr import solve_ab
from GNN_local_refinement import run_refinement_chain, run_multi_refinement

def setup_logger(prefix="main_gnn"):
    """Redirect this process's stdout and stderr into a new, line-buffered log file.

    The file is created in the current working directory and named
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``. It is deliberately left open: it
    becomes ``sys.stdout``/``sys.stderr`` for the rest of the process.

    Args:
        prefix: Leading part of the log file name.

    Returns:
        The log file name (callers that ignore the return value are unaffected).
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    log_filename = f"{prefix}_{timestamp}_{os.getpid()}.log"
    # Explicit UTF-8: this script prints non-ASCII (Chinese) status messages, which
    # would raise UnicodeEncodeError if the platform's default encoding is not UTF-8.
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")
    return log_filename

def load_best_model_into(model, ckpt_path, device):
    """Load the checkpoint at ``ckpt_path`` into ``model`` and return the model to use.

    Accepted checkpoint layouts:
      * a dict containing ``'model_state'`` -> that state dict is loaded into ``model``;
      * an object with a ``state_dict()`` method (e.g. a whole saved ``nn.Module``)
        -> its weights are copied into ``model``; if that copy raises, the loaded
        object itself is moved to ``device`` and returned instead of ``model``;
      * anything else is treated as a plain state dict for ``model``.

    All loads use ``strict=False``: missing/unexpected keys are ignored silently
    (tensor shape mismatches still raise).

    Args:
        model: Module that receives the weights.
        ckpt_path: Path of a file written by ``torch.save``.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``model`` with weights loaded in place, or the unpickled checkpoint object
        in the fallback case described above.
    """
    # torch>=2.6 defaults torch.load to weights_only=True, which refuses to unpickle
    # a whole nn.Module and makes the second layout above unloadable. These files are
    # expected to be this project's own checkpoints, so full unpickling is requested
    # explicitly (do not point this at untrusted files). torch versions that predate
    # the weights_only argument fall back to the plain call.
    try:
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    except TypeError:
        ckpt = torch.load(ckpt_path, map_location=device)

    if isinstance(ckpt, dict) and 'model_state' in ckpt:
        model.load_state_dict(ckpt['model_state'], strict=False)
        return model
    elif hasattr(ckpt, 'state_dict'):
        try:
            model.load_state_dict(ckpt.state_dict(), strict=False)
            return model
        except Exception:
            # Architecture could not absorb the saved weights: use the saved object itself.
            ckpt = ckpt.to(device)
            return ckpt
    else:
        model.load_state_dict(ckpt, strict=False)
        return model

def maybe_freeze(model, freeze=True):
    """Switch ``model`` to eval mode, first disabling gradients when ``freeze`` is true.

    Args:
        model: Any ``nn.Module``; modified in place.
        freeze: When true, every parameter gets ``requires_grad = False``.

    Returns:
        None.
    """
    if freeze:
        # Module.requires_grad_ walks model.parameters() and flips each flag.
        model.requires_grad_(False)
    model.eval()

# Command-line configuration. parse_args() runs at import time (end of this block),
# so importing this module parses sys.argv.
# Most options use nargs='?' with const=1: passing a flag with no value sets it to 1.
parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#                 Configure parameters up front for later use                 #
###############################################################################

parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6000))
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(50))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1000))
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 1)
# 'train' trains (or loads) both GNN stages before testing; any other value only tests.
parser.add_argument('--mode', nargs='?', const=1, type=str, default='test')
# Directory containing this script (os.path.join with a single argument returns it unchanged).
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
# 'imbalanced' makes the test stage sample graphs with explicit class sizes;
# any other value samples balanced graphs.
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): --freeze_bn stores True into eval_vs_train, which already defaults
# to True, so passing the flag currently has no effect — confirm the intended semantics.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

###############################################################################
#                                 GNN first period                            #
###############################################################################
parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default= 8)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default= 3)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=2)
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

args = parser.parse_args()

# Legacy tensor-type aliases; dtype_l is used to cast labels via ``.type(dtype_l)``.
# CUDA variants when a GPU is visible, CPU variants otherwise (only the chosen
# branch is evaluated, so torch.cuda.* is never touched on CPU-only machines).
dtype, dtype_l = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor)
    if torch.cuda.is_available()
    else (torch.FloatTensor, torch.LongTensor)
)

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Column layouts for the per-iteration log table (header row / value row).
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'

template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'

class SBMDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-generated graphs stored one per ``.npz`` file.

    Each file must contain the CSR components of the adjacency matrix
    (``adj_data``, ``adj_indices``, ``adj_indptr``, ``adj_shape``) and a
    ``labels`` array.
    """

    def __init__(self, npz_file_list):
        """Args:
            npz_file_list: Sequence of ``.npz`` paths, one graph per file.
        """
        self.files = npz_file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load graph ``idx``.

        Returns:
            dict with ``'adj'`` (``scipy.sparse.csr_matrix``), ``'labels'`` (the
            stored labels array) and ``'num_nodes'`` (``adj.shape[0]``).
        """
        # np.load on an .npz returns an NpzFile that keeps the file open; use it as
        # a context manager so the handle is closed right away instead of leaking
        # one descriptor per sample (per DataLoader worker) until garbage collection.
        # Indexing the NpzFile reads each array fully into memory, so the arrays
        # remain valid after the file is closed.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}

# Sparse helpers
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors

import pandas as pd
import numpy as np

def get_available_device():
    """Return the first CUDA device that can allocate a tensor, else the CPU device.

    Each GPU index is made current with ``torch.cuda.set_device`` and probed by
    allocating a one-element tensor on it; any ``RuntimeError`` during the probe
    moves on to the next index. Side effect: when a GPU is returned it is left as
    the current CUDA device.
    """
    n_gpus = torch.cuda.device_count()
    idx = 0
    while idx < n_gpus:
        try:
            torch.cuda.set_device(idx)
            torch.zeros(1).cuda()
        except RuntimeError:
            idx += 1
            continue
        return torch.device(f"cuda:{idx}")
    return torch.device("cpu")


# Device chosen once at import time and shared by the rest of the script.
device = get_available_device()

def test_single_first_period(gnn_first_period, gnn_second_period, gen, n_classes, args, iter, mode='balanced', class_sizes=None):
    """Evaluate the two-stage GNN pipeline on one freshly sampled graph.

    Steps: sample a graph from ``gen``, run the first-period GNN, feed its labels
    to ``run_refinement_chain`` together with the second-period GNN, and also apply
    neighbour majority-vote refinement (``local_refinement_by_neighbors``) to the
    first-period prediction. One log line (loss, accuracy, elapsed time) is printed.

    Args:
        gnn_first_period: First-stage model, called as ``model(WW, x)``.
        gnn_second_period: Second-stage model handed to ``run_refinement_chain``.
        gen: Graph generator providing ``sample_otf_single`` /
            ``imbalanced_sample_otf_single``.
        n_classes: Number of communities.
        args: Parsed CLI arguments (reads ``J``, ``edge_density``, ``noise`` and
            whatever ``run_refinement_chain`` uses).
        iter: Iteration index; only used in the printed log line.
        mode: ``'imbalanced'`` samples with ``class_sizes`` (order reversed with
            probability 1/2); any other value samples a balanced graph.
        class_sizes: Community sizes; required when ``mode == 'imbalanced'``.

    Returns:
        ``(loss, metrics)``: the first-period loss as a float, and a dict mapping
        ``'gnn'``, ``'gnn_second'``, ``'gnn_final'`` and ``'gnn_refined'`` to
        ``{'acc', 'ari', 'nmi'}`` floats. ``'gnn_second'`` is NaN when the
        refinement chain reports no ``'first_iter'`` result.

    Raises:
        ValueError: if ``mode == 'imbalanced'`` and ``class_sizes`` is None.
    """
    def _sync():
        # Wait for queued CUDA work so timings are meaningful (no-op on CPU).
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    t_all0 = time.perf_counter()
    # NOTE(review): evaluation runs with both models in train() mode, as in the
    # original code; presumably intentional (e.g. per-graph normalization stats) — confirm.
    gnn_first_period.train()
    gnn_second_period.train()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ===== 1) Sample a graph =====
    if mode == 'imbalanced':
        if class_sizes is None:
            raise ValueError("class_sizes is required when mode == 'imbalanced'")
        # Randomly reverse the order so each community is the small one half the time.
        if random.random() < 0.5:
            random_class_sizes = class_sizes[::-1]
        else:
            random_class_sizes = class_sizes
        W, true_labels, _eigvecs_top = gen.imbalanced_sample_otf_single(random_class_sizes, is_training=True, cuda=True)
    else:
        W, true_labels, _eigvecs_top = gen.sample_otf_single(is_training=True, cuda=True)
    true_labels = true_labels.type(dtype_l)

    W_np = W.detach().cpu().numpy() if torch.is_tensor(W) else np.asarray(W)

    # ===== 2) Spectral clustering baselines (currently disabled) =====
    # res_spectral = spectral_clustering_adj(
    #     W_np, n_classes, true_labels,
    #     normalized=True,
    #     run_all=True,
    #     random_state=0
    # )

    # ===== 3) Build GNN inputs =====
    W = W.to(device)  # shape: (B, N, N)

    # Per the original note: WW is (B, N, N, J+3) and x is (B, N, d).
    WW, x = get_gnn_inputs(W.detach().cpu().numpy(), args.J)
    WW = WW.clone().detach().to(torch.float32).to(device)
    x = x.clone().detach().to(torch.float32).to(device)

    # ===== 4) First-period GNN forward =====
    with torch.no_grad():
        _sync()
        pred_single_first = gnn_first_period(WW, x)   # (B, N, n_classes)
        _sync()
        start_x_label = from_scores_to_labels_multiclass_batch(
            pred_single_first.detach().cpu().numpy()
        )

    # ===== 5) Second-period refinement chain =====
    with torch.no_grad():
        ref = run_refinement_chain(
            gnn_second_period=gnn_second_period,
            W_np=W_np,                      # per the original note: accepts torch or numpy
            init_labels=start_x_label,
            true_labels=true_labels,
            args=args,
            device=device,
            total_iters=10,
            verbose=False
        )
        _sync()

    # Default the first-iteration metrics to NaN so the metrics dict below can
    # always be built, even when the chain reports no 'first_iter' result.
    first_iter = ref.get("first_iter")
    if first_iter is not None:
        acc_gnn_second = first_iter["acc"]
        ari_gnn_second = first_iter["ari"]
        nmi_gnn_second = first_iter["nmi"]
    else:
        acc_gnn_second = ari_gnn_second = nmi_gnn_second = float("nan")

    acc_gnn_final = ref["final"]["acc"]
    ari_gnn_final = ref["final"]["ari"]
    nmi_gnn_final = ref["final"]["nmi"]

    # ===== 6) First-period loss & metrics =====
    W_np = W.detach().cpu().numpy() if isinstance(W, Tensor) else np.asarray(W)

    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    _sync()
    acc_gnn_first, gnn_pred_label_first, ari_gnn_first, nmi_gnn_first = gnn_compute_acc_ari_nmi_multiclass(
        pred_single_first, true_labels, n_classes
    )

    # ===== 7) Neighbour majority-vote local refinement =====
    gnn_refined = local_refinement_by_neighbors(W_np, gnn_pred_label_first, n_classes)
    acc_gnn_refined, _, ari_gnn_refined, nmi_gnn_refined = compute_acc_ari_nmi(
        gnn_refined, true_labels, n_classes
    )

    # ===== 8) Summary & log line (total time only) =====
    total_elapsed = time.perf_counter() - t_all0
    loss_value = float(loss_test_first.detach().cpu())

    info = ['iter', 'avg loss', 'avg acc', 'edge_density', 'noise', 'model', 'elapsed']
    out  = [iter, loss_value, acc_gnn_first, args.edge_density, args.noise, 'GNN', total_elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    del WW, x

    # --- unpack spectral clustering results (disabled together with step 2) ---
    # acc_sc_n, ari_sc_n, nmi_sc_n = res_spectral['normalized']['sc']
    # acc_rf_n, ari_rf_n, nmi_rf_n = res_spectral['normalized']['refined']
    #
    # acc_sc_u, ari_sc_u, nmi_sc_u = res_spectral['unnormalized']['sc']
    # acc_rf_u, ari_rf_u, nmi_rf_u = res_spectral['unnormalized']['refined']
    #
    # acc_sc_a, ari_sc_a, nmi_sc_a = res_spectral['adjacency']['sc']
    # acc_rf_a, ari_rf_a, nmi_rf_a = res_spectral['adjacency']['refined']

    metrics = {
        "gnn": {"acc": float(acc_gnn_first), "ari": float(ari_gnn_first), "nmi": float(nmi_gnn_first)},
        "gnn_second": {"acc": float(acc_gnn_second), "ari": float(ari_gnn_second), "nmi": float(nmi_gnn_second)},
        "gnn_final":  {"acc": float(acc_gnn_final),  "ari": float(ari_gnn_final),  "nmi": float(nmi_gnn_final)},
        "gnn_refined": {"acc": float(acc_gnn_refined), "ari": float(ari_gnn_refined), "nmi": float(nmi_gnn_refined)},
        # "spectral_normalized": {"acc": float(acc_sc_n), "ari": float(ari_sc_n), "nmi": float(nmi_sc_n)},
        # "spectral_normalized_refined": {"acc": float(acc_rf_n), "ari": float(ari_rf_n), "nmi": float(nmi_rf_n)},
        #
        # "spectral_unnormalized": {"acc": float(acc_sc_u), "ari": float(ari_sc_u), "nmi": float(nmi_sc_u)},
        # "spectral_unnormalized_refined": {"acc": float(acc_rf_u), "ari": float(ari_rf_u), "nmi": float(nmi_rf_u)},
        #
        # "spectral_adjacency": {"acc": float(acc_sc_a), "ari": float(ari_sc_a), "nmi": float(nmi_sc_a)},
        # "spectral_adjacency_refined": {"acc": float(acc_rf_a), "ari": float(ari_rf_a), "nmi": float(nmi_rf_a)},
    }

    return loss_value, metrics


# Method keys aggregated by test_first_period, in output-column order.
# test_single_first_period currently produces only the gnn* entries (its spectral
# code is commented out), so the spectral_* columns come out as NaN.
METHODS = [
    # GNN pipeline stages
    "gnn",
    "gnn_second",
    "gnn_final",
    "gnn_refined",
    # spectral-clustering variants and their refined versions
    "spectral_normalized",
    "spectral_normalized_refined",
    "spectral_unnormalized",
    "spectral_unnormalized_refined",
    "spectral_adjacency",
    "spectral_adjacency_refined",
]

def test_first_period(
    gnn_first_period,
    gnn_second_period,
    n_classes,
    gen,
    args,
    iters=None,
    mode='balanced',
    class_sizes=None,
):
    """Evaluate the pipeline on ``iters`` sampled graphs and average the metrics.

    Each iteration calls ``test_single_first_period``. ACC/ARI/NMI are collected
    for every key in ``METHODS`` (missing or non-finite values are stored as NaN)
    and averaged while ignoring NaNs.

    Args:
        gnn_first_period, gnn_second_period: Models forwarded to
            ``test_single_first_period``.
        n_classes: Number of communities.
        gen: Graph generator; its ``p_SBM`` / ``q_SBM`` also feed the SNR column.
        args: Parsed CLI arguments.
        iters: Number of graphs to evaluate; defaults to ``args.num_examples_test``.
        mode: Sampling mode forwarded to ``test_single_first_period``.
        class_sizes: Community sizes for imbalanced sampling.

    Returns:
        ``(row_acc, row_ari, row_nmi)``: three dicts sharing the metadata columns
        ``n_classes, class_sizes, p_SBM, q_SBM, J, N_train, N_test, SNR`` plus one
        column per entry of ``METHODS``. A method with no finite value across all
        iterations (or when ``iters == 0``) gets NaN.
    """
    if iters is None:
        iters = args.num_examples_test

    gnn_first_period.train()
    buckets = {m: {"acc": [], "ari": [], "nmi": []} for m in METHODS}

    for it in range(iters):
        _loss_val, metrics = test_single_first_period(
            gnn_first_period=gnn_first_period,
            gnn_second_period=gnn_second_period,
            gen=gen,
            n_classes=n_classes,
            args=args,
            iter=it,
            mode=mode,
            class_sizes=class_sizes,
        )

        for meth in METHODS:
            md = metrics.get(meth, {})
            for k in ("acc", "ari", "nmi"):
                v = md.get(k, np.nan)
                buckets[meth][k].append(float(v) if np.isfinite(v) else np.nan)

        torch.cuda.empty_cache()

    def mean(vals):
        # NaN-ignoring mean. Handle the all-NaN / empty case explicitly: np.nanmean
        # would also return NaN there, but emits a "Mean of empty slice"
        # RuntimeWarning for every method that is never computed.
        arr = np.asarray(vals, dtype=float)
        if np.isnan(arr).all():
            return float("nan")
        return float(np.nanmean(arr))

    # SNR of the sampled SBM: with p = a*log(n)/n and q = b*log(n)/n,
    # SNR = (a - b)^2 / (k * (a + (k - 1) * b)).
    n = getattr(args, "N_train", getattr(gen, "N_train", None)) or getattr(args, "N_test", getattr(gen, "N_test", -1))
    logn_div_n = np.log(n) / n if n and n > 0 else np.nan
    a = gen.p_SBM / logn_div_n if np.isfinite(logn_div_n) else np.nan
    b = gen.q_SBM / logn_div_n if np.isfinite(logn_div_n) else np.nan
    k = n_classes
    snr = (a - b) ** 2 / (k * (a + (k - 1) * b)) if np.all(np.isfinite([a, b])) else np.nan

    meta = {
        "n_classes": int(n_classes),
        "class_sizes": str(class_sizes) if class_sizes is not None else "",
        "p_SBM": float(gen.p_SBM),
        "q_SBM": float(gen.q_SBM),
        "J": int(args.J),
        "N_train": int(getattr(args, "N_train", getattr(gen, "N_train", -1))),
        "N_test": int(getattr(args, "N_test", getattr(gen, "N_test", -1))),
        "SNR": float(snr) if np.isfinite(snr) else np.nan,
    }

    row_acc = {**meta}
    row_ari = {**meta}
    row_nmi = {**meta}
    for meth in METHODS:
        row_acc[meth] = mean(buckets[meth]["acc"])
        row_ari[meth] = mean(buckets[meth]["ari"])
        row_nmi[meth] = mean(buckets[meth]["nmi"])

    return row_acc, row_ari, row_nmi

from pathlib import Path

def append_rows_to_excel(row_acc: dict, row_ari: dict, row_nmi: dict, filename="summary.xlsx", extra_info: dict=None):
    """Append one row to each of the ACC, ARI and NMI sheets of an Excel workbook.

    ``extra_info`` (optional) is merged into every row, overriding same-named
    keys; a list/tuple/ndarray ``class_sizes`` value is rendered as ``"a-b-..."``,
    any other value via ``str``. If the workbook exists, the row is appended below
    the sheet's existing rows (columns are unioned, existing ones first); if the
    sheet cannot be read (``ValueError``, e.g. it does not exist yet) the sheet is
    rewritten with just the new row. A missing workbook is created. Uses the
    ``openpyxl`` engine.
    """
    extras = {}
    if extra_info:
        extras = dict(extra_info)
        if "class_sizes" in extras:
            sizes = extras["class_sizes"]
            extras["class_sizes"] = (
                "-".join(map(str, sizes))
                if isinstance(sizes, (list, tuple, np.ndarray))
                else str(sizes)
            )

    target = Path(filename)

    def _write_sheet(sheet: str, row: dict):
        fresh = pd.DataFrame([{**row, **extras}])

        if not target.exists():
            with pd.ExcelWriter(filename, engine="openpyxl", mode="w") as writer:
                fresh.to_excel(writer, sheet_name=sheet, index=False)
            return

        try:
            previous = pd.read_excel(target, sheet_name=sheet)
            union = list(dict.fromkeys([*previous.columns, *fresh.columns]))
            previous = previous.reindex(columns=union)
            fresh = fresh.reindex(columns=union)
            combined = pd.concat([previous, fresh], ignore_index=True)
        except ValueError:
            combined = fresh
        with pd.ExcelWriter(filename, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
            combined.to_excel(writer, sheet_name=sheet, index=False)

    for sheet, row in (("ACC", row_acc), ("ARI", row_ari), ("NMI", row_nmi)):
        _write_sheet(sheet, row)

def test_first_period_wrapper(args_tuple):
    """Evaluate one (class_sizes, SNR, total_ab) test configuration.

    ``args_tuple`` is ``(gnn_first_period, gnn_second_period, class_sizes, snr,
    gen, logN_div_N, total_ab)``. ``solve_ab`` supplies the SBM constants
    ``a, b``; the edge probabilities ``p = a * logN_div_N`` and
    ``q = b * logN_div_N`` (rounded to 4 decimals) are set on ``gen.copy()``
    rather than on ``gen`` itself.

    Note: reads the module-level ``args`` (N_train, n_classes, num_examples_test,
    mode_isbalanced) instead of receiving it as a parameter.

    Returns:
        ``(row_acc, row_ari, row_nmi)`` as produced by ``test_first_period``.
    """
    (first_model, second_model, class_sizes, snr,
     base_gen, log_n_over_n, ab_total) = args_tuple

    a, b = solve_ab(args.N_train, class_sizes, args.n_classes, snr, ab_total)
    p_SBM, q_SBM = (round(coef * log_n_over_n, 4) for coef in (a, b))

    local_gen = base_gen.copy()
    local_gen.p_SBM, local_gen.q_SBM = p_SBM, q_SBM

    print(f"\n[测试阶段] class_sizes: {class_sizes}, SNR: {snr:.2f}")
    print(f"使用的 SBM 参数: p={p_SBM}, q={q_SBM}")

    return test_first_period(
        gnn_first_period=first_model,
        gnn_second_period=second_model,
        n_classes=args.n_classes,
        gen=local_gen,
        args=args,
        iters=args.num_examples_test,
        mode=args.mode_isbalanced,
        class_sizes=class_sizes,
    )

def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))

if __name__ == '__main__':

    def _load_full_module(path):
        """Load a checkpoint saved as a whole pickled module onto ``device``.

        torch>=2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses to
        unpickle ``nn.Module`` objects. These files are expected to be written by
        this script's own training stage, so full unpickling is requested explicitly
        (do not point it at untrusted files). torch versions that predate the
        ``weights_only`` argument fall back to the plain call.
        """
        try:
            return torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            return torch.load(path, map_location=device)

    try:
        # Set up logging first: everything below, including tracebacks, goes to the log file.
        setup_logger("run_gnn")

        # Configure the graph generator from the CLI arguments.
        gen = Generator()
        gen.N_train = args.N_train
        gen.N_test = args.N_test
        gen.N_val = args.N_val
        gen.edge_density = args.edge_density
        gen.p_SBM = args.p_SBM
        gen.q_SBM = args.q_SBM
        gen.random_noise = args.random_noise
        gen.noise = args.noise
        gen.noise_model = args.noise_model
        gen.generative_model = args.generative_model
        gen.n_classes = args.n_classes
        gen.num_examples_train = args.num_examples_train
        gen.num_examples_test = args.num_examples_test
        gen.num_examples_val = args.num_examples_val

        # Checkpoints live in model_GNN/GNN_model_first_classes{n_classes}/; the file
        # names encode each stage's architecture hyper-parameters.
        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        filename_first = f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            # Only the SBM_multiclass architectures are implemented. Fail before the
            # expensive data preparation instead of hitting a NameError later.
            if args.generative_model != 'SBM_multiclass':
                raise ValueError(
                    f"training only supports --generative_model SBM_multiclass, got {args.generative_model!r}"
                )

            gen.prepare_data()

            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            train_loader = DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=True,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            val_loader = DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            test_loader = DataLoader(
                test_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )

            # ----- Stage 1: first-period GNN (load best weights if present, else train) -----
            print(f"[阶段一] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers}, J={args.J}, num_features={args.num_features}")
            if os.path.exists(path_first):
                print(f"[阶段一] 检测到已有‘最优权重’，直接载入: {path_first}")
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes)
                gnn_first_period = gnn_first_period.to(device)
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)
            else:
                gnn_first_period = GNN_multiclass(args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes)
                if torch.cuda.is_available():
                    gnn_first_period = gnn_first_period.to(device)

                loss_list, acc_list = train_first_period_with_early_stopping(
                    gnn_first_period, train_loader, val_loader, args.n_classes, args,
                    epochs=100, patience=5, save_path=path_first, filename=filename_first
                )
                print(f"[阶段一] 从最优权重回载到内存: {path_first}")
                gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)

            # Stage 1 stays fixed while stage 2 trains.
            maybe_freeze(gnn_first_period, freeze=True)

            # ----- Stage 2: second-period (refinement) GNN -----
            print(f"[阶段二] 训练 GNN：n_classes={args.n_classes}, layers={args.num_layers_second}, J={args.J_second}, num_features={args.num_features_second}")
            if os.path.exists(path_second):
                print(f"[阶段二] 检测到已有模型文件，直接载入: {path_second}")
                gnn_second_period = _load_full_module(path_second)
                gnn_second_period = gnn_second_period.to(device)
            else:
                gnn_second_period = GNN_multiclass_second_period(
                    args.num_features_second, args.num_layers_second, args.J_second + 3, n_classes=args.n_classes
                )
                if torch.cuda.is_available():
                    gnn_second_period = gnn_second_period.to(device)

                loss_list, acc_list = train_second_period_with_early_stopping(
                    gnn_first_period, gnn_second_period, train_loader, val_loader,
                    args.n_classes, args,
                    epochs=100, patience=5, save_path=path_second, filename=filename_second
                )
                print(f"[阶段二] 保存模型到 {path_second}")

        print("[测试阶段] 开始...")

        # Load both models onto the current device.
        # NOTE(review): this assumes both checkpoints are whole pickled modules; if the
        # training code saves a {'model_state': ...} dict, load_best_model_into is needed.
        gnn_first_period = _load_full_module(path_first)
        gnn_first_period = gnn_first_period.to(device)
        gnn_second_period = _load_full_module(path_second)
        gnn_second_period = gnn_second_period.to(device)

        class_sizes_list = [
            [200, 800],
            [300, 700],
            [400, 600],
            [500, 500]
        ]
        snr_list = [0.5, 1, 1.5, 2]
        total_ab_list = [5, 10, 15]

        # Every entry of class_sizes_list sums to N = 1000 nodes.
        N = 1000
        logN_div_N = np.log(N) / N

        task_args = []
        for total_ab in total_ab_list:
            for class_sizes in class_sizes_list:
                for snr in snr_list:
                    task_args.append((gnn_first_period, gnn_second_period, class_sizes, snr, gen, logN_div_N, total_ab))

        print("[测试阶段] 顺序执行（GPU 单卡逐个任务）...")
        results = []
        for targs in tqdm(task_args, total=len(task_args), desc="[测试] 进度(顺序)"):
            r = test_first_period_wrapper(targs)
            results.append(r)
            torch.cuda.empty_cache()

        # One workbook per total_ab; each task appends one row per metric sheet.
        for (row_acc, row_ari, row_nmi), (_, _, class_sizes, snr, _, _, total_ab) in zip(results, task_args):
            filename = f"summary_total{total_ab}.xlsx"
            append_rows_to_excel(
                row_acc, row_ari, row_nmi,
                filename=filename,
                extra_info={"class_sizes": class_sizes, "snr": snr, "total_ab": total_ab}
            )

    except Exception:
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells / job schedulers see the failure instead of a clean exit.
        sys.exit(1)
