import argparse
import multiprocessing
import os
import random
import sys
import time

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader

from GNN_local_refinement import run_refinement_chain
from controlsnr import solve_ab
from data_generator_sbm import Generator, simple_collate_fn
from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_acc_ari_nmi, \
    gnn_compute_acc_ari_nmi_multiclass
from losses import from_scores_to_labels_multiclass_batch
from models import GNN_multiclass, GNN_multiclass_second_period
from train_first_period import train_first_period_with_early_stopping
from train_second_period import train_second_period_with_early_stopping


def setup_logger(prefix="main_gnn"):
    """Redirect this process's stdout and stderr to a freshly created log file.

    The file is created in the current working directory and is named
    ``{prefix}_{YYYYmmdd_HHMMSS}_{pid}.log``. It is opened line-buffered and is
    intentionally left open: after this call ``sys.stdout`` and ``sys.stderr``
    both refer to it, so every later ``print`` and traceback ends up there.

    Args:
        prefix: Leading part of the log file name.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    pid = os.getpid()
    log_filename = f"{prefix}_{timestamp}_{pid}.log"
    # Explicit UTF-8: this script prints non-ASCII status messages, which would
    # raise UnicodeEncodeError if the platform default encoding cannot
    # represent them.
    logfile = open(log_filename, "w", buffering=1, encoding="utf-8")
    sys.stdout = logfile
    sys.stderr = logfile
    print(f"[Logger initialized] Logging to: {log_filename}")

def load_best_model_into(model, ckpt_path, device):
    """Load a checkpoint from ``ckpt_path`` and return a ready-to-use model.

    Three checkpoint layouts are recognised:

    * a dict with a ``'model_state'`` entry -> that entry is loaded into ``model``;
    * an object exposing ``state_dict()`` -> its weights are copied into
      ``model``; if that raises, the loaded object itself is moved to
      ``device`` and returned in place of ``model``;
    * anything else -> treated as a state dict and loaded into ``model``.

    Weights are always loaded with ``strict=False``, so missing or unexpected
    keys do not raise.

    Args:
        model: Module that receives the weights.
        ckpt_path: Path passed to ``torch.load``.
        device: Used as ``map_location`` and as the target for the fallback.

    Returns:
        ``model`` with weights loaded, or the loaded object in the fallback case.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    loaded = torch.load(ckpt_path, map_location=device)

    if isinstance(loaded, dict) and 'model_state' in loaded:
        weights = loaded['model_state']
    elif hasattr(loaded, 'state_dict'):
        try:
            model.load_state_dict(loaded.state_dict(), strict=False)
            return model
        except Exception:
            # Could not copy the weights over: hand back the loaded object.
            loaded = loaded.to(device)
            return loaded
    else:
        weights = loaded

    model.load_state_dict(weights, strict=False)
    return model

def maybe_freeze(model, freeze=True):
    """Switch ``model`` to eval mode, optionally disabling gradients first.

    Args:
        model: Module to modify in place.
        freeze: When true, every parameter gets ``requires_grad`` set to False.
            Eval mode is applied regardless of this flag.
    """
    frozen_params = model.parameters() if freeze else ()
    for param in frozen_params:
        param.requires_grad_(False)
    model.eval()

# Command-line interface. Arguments are parsed at import time (see
# ``args = parser.parse_args()`` at the end of this section), so importing this
# module from another program also consumes that program's ``sys.argv``.
#
# NOTE(review): most options use ``nargs='?'`` with ``const=1``: passing the
# flag without a value stores the constant 1 instead of the default.
parser = argparse.ArgumentParser()

###############################################################################
#                             General Settings                                #
#            Configure parameters up front for later use                      #
###############################################################################

# Dataset sizes (number of graphs per split).
parser.add_argument('--num_examples_train', nargs='?', const=1, type=int,
                    default=int(6000))
parser.add_argument('--num_examples_test', nargs='?', const=1, type=int,
                    default=int(50))
parser.add_argument('--num_examples_val', nargs='?', const=1, type=int,
                    default=int(1000))
# Graph-model parameters copied onto the Generator in the __main__ block.
parser.add_argument('--edge_density', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--p_SBM', nargs='?', const=1, type=float,
                    default=0.2)
parser.add_argument('--q_SBM', nargs='?', const=1, type=float,
                    default=0.1)
parser.add_argument('--class_sizes', type=int, nargs='+', default=[100, 1000],
                    help='List of class sizes for imbalanced SBM')
parser.add_argument('--random_noise', action='store_true')
parser.add_argument('--noise', nargs='?', const=1, type=float, default=2)
parser.add_argument('--noise_model', nargs='?', const=1, type=int, default=2)
#########################
#parser.add_argument('--generative_model', nargs='?', const=1, type=str,
#                    default='ErdosRenyi')
parser.add_argument('--generative_model', nargs='?', const=1, type=str,
                    default='SBM_multiclass')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default= 1)
# NOTE(review): the default mode is 'test', but the __main__ block in this file
# only acts on mode == "train".
parser.add_argument('--mode', nargs='?', const=1, type=str, default='test')
# Directory containing this script; used as the default for --path_gnn.
default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
parser.add_argument('--mode_isbalanced', nargs='?', const=1, type=str, default='imbalanced')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default=default_path)
parser.add_argument('--path_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn_local_refinement', nargs='?', const=1, type=str, default='')

parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=1)
parser.add_argument('--test_freq', nargs='?', const=1, type=int, default=500)
parser.add_argument('--save_freq', nargs='?', const=1, type=int, default=2000)
parser.add_argument('--clip_grad_norm', nargs='?', const=1, type=float,
                    default=10.0)
# NOTE(review): because set_defaults() forces eval_vs_train=True and the flag
# is store_true, args.eval_vs_train is True whether or not --freeze_bn is
# given, so the flag currently has no effect.
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=True)

###############################################################################
#                                 GNN Settings                                #
###############################################################################

###############################################################################
#                                 GNN first period                            #
###############################################################################
parser.add_argument('--num_features', nargs='?', const=1, type=int,
                    default=16)
parser.add_argument('--num_layers', nargs='?', const=1, type=int,
                    default=30)
parser.add_argument('--J', nargs='?', const=1, type=int, default= 2)

###############################################################################
#                                 GNN second period                            #
###############################################################################
parser.add_argument('--num_features_second', nargs='?', const=1, type=int,
                    default= 8)
parser.add_argument('--num_layers_second', nargs='?', const=1, type=int,
                    default= 3)
parser.add_argument('--J_second', nargs='?', const=1, type=int, default= 1)

parser.add_argument('--n_classes', nargs='?', const=1, type=int,
                    default=2)
# Number of nodes per graph for each split.
parser.add_argument('--N_train', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_test', nargs='?', const=1, type=int, default=1000)
parser.add_argument('--N_val', nargs='?', const=1, type=int, default=1000)

parser.add_argument('--lr', nargs='?', const=1, type=float, default=4e-3)

args = parser.parse_args()

# Legacy tensor-type aliases: CUDA variants when a GPU is visible, CPU variants
# otherwise. ``dtype_l`` is used below to cast label tensors via ``.type(...)``.
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    dtype_l = torch.cuda.LongTensor
else:
    dtype = torch.FloatTensor
    dtype_l = torch.LongTensor

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()
# Column layouts for the per-iteration progress table: template1 formats the
# header row, template2 the matching value row (followed by a blank line).
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10}'
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'

# Wider nine-column header/row layouts; not referenced elsewhere in this file.
template_header = '{:<6} {:<10} {:<10} {:<13} {:<10} {:<8} {:<10} {:<10} {:<20}'
template_row    = '{:<6} {:<10.5f} {:<10.5f} {:<13} {:<10} {:<8} {:<10.3f} {:<10.4f} {:<20}'

class SBMDataset(torch.utils.data.Dataset):
    """Dataset that reads one graph per ``.npz`` file.

    Each archive must contain the CSR components of the adjacency matrix
    (``adj_data``, ``adj_indices``, ``adj_indptr``, ``adj_shape``) and a
    ``labels`` array.
    """

    def __init__(self, npz_file_list):
        """Store the list of ``.npz`` paths; nothing is read until indexing."""
        self.files = npz_file_list

    def __len__(self):
        """Return the number of graph files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load graph ``idx`` and return ``{'adj', 'labels', 'num_nodes'}``.

        ``adj`` is a ``scipy.sparse.csr_matrix``, ``labels`` is the array stored
        in the file, and ``num_nodes`` is the first dimension of ``adj``.
        """
        # np.load on an .npz returns a lazily-read archive that holds the file
        # open; the context manager closes it once the arrays are extracted.
        with np.load(self.files[idx]) as data:
            adj = csr_matrix(
                (data['adj_data'], data['adj_indices'], data['adj_indptr']),
                shape=tuple(data['adj_shape']),
            )
            labels = data['labels']
        return {'adj': adj, 'labels': labels, 'num_nodes': adj.shape[0]}

# Sparse helpers
from scipy.sparse import csr_matrix
from spectral_clustering import local_refinement_by_neighbors

import pandas as pd
import numpy as np

def get_available_device():
    """Return the first CUDA device that accepts an allocation, else the CPU.

    Devices are tried in index order. A device counts as usable when selecting
    it and moving a one-element tensor onto it does not raise ``RuntimeError``.
    As a side effect, the returned CUDA device is left as the current device.
    """
    gpu_total = torch.cuda.device_count()
    gpu_index = 0
    while gpu_index < gpu_total:
        try:
            torch.cuda.set_device(gpu_index)
            probe = torch.zeros(1)
            probe.cuda()
            return torch.device(f"cuda:{gpu_index}")
        except RuntimeError:
            # This device is unusable; try the next one.
            gpu_index += 1
    return torch.device("cpu")

# Module-level device, probed once at import time; the __main__ block uses it
# when moving models and loading checkpoints.
device = get_available_device()

def test_single_first_period(gnn_first_period, gnn_second_period, gen, n_classes, args, iter, mode='balanced', class_sizes=None):
    """Evaluate both GNN stages on one freshly sampled graph.

    Steps: sample a graph from ``gen``, run the first-stage GNN, feed its labels
    to ``run_refinement_chain`` with the second-stage GNN, apply a
    neighbour-based local refinement to the first-stage labels, and print a
    one-row summary (loss, first-stage accuracy, total elapsed time).

    Args:
        gnn_first_period: First-stage model, called as ``model(WW, x)``.
        gnn_second_period: Second-stage model handed to the refinement chain.
        gen: Graph generator providing ``sample_otf_single`` and
            ``imbalanced_sample_otf_single``.
        n_classes: Number of communities.
        args: Parsed options; ``J``, ``edge_density`` and ``noise`` are read here.
        iter: Iteration index, used only in the printed row.
        mode: ``'imbalanced'`` samples with ``class_sizes`` (reversed with
            probability 0.5); any other value uses the balanced sampler.
        class_sizes: Class sizes for the imbalanced sampler.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the first-stage loss as a float
        and ``metrics`` maps ``"gnn"``, ``"gnn_second"``, ``"gnn_final"`` and
        ``"gnn_refined"`` to ``{"acc", "ari", "nmi"}`` dicts of floats.
        ``"gnn_second"`` holds NaN values when the refinement chain reports no
        ``"first_iter"`` entry.
    """
    def _sync():
        # Wait for queued CUDA work so the elapsed time is meaningful.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    t_all0 = time.perf_counter()
    # NOTE(review): both models are put in train() mode although this is an
    # evaluation routine — confirm this is intentional.
    gnn_first_period.train()
    gnn_second_period.train()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ===== 1) Sample a graph =====
    if mode == 'imbalanced':
        if random.random() < 0.5:
            random_class_sizes = class_sizes[::-1]
        else:
            random_class_sizes = class_sizes
        W, true_labels, _ = gen.imbalanced_sample_otf_single(random_class_sizes, is_training=True, cuda=True)
    else:
        W, true_labels, _ = gen.sample_otf_single(is_training=True, cuda=True)
    true_labels = true_labels.type(dtype_l)

    W_np = W.detach().cpu().numpy() if torch.is_tensor(W) else np.asarray(W)

    # ===== 2) Build the GNN inputs =====
    # Shapes per the original author's notes (TODO confirm):
    # W: (B, N, N), WW: (B, N, N, J+3), x: (B, N, d).
    W = W.to(device)
    WW, x = get_gnn_inputs(W.detach().cpu().numpy(), args.J)
    WW = WW.clone().detach().to(torch.float32).to(device)
    x = x.clone().detach().to(torch.float32).to(device)

    # ===== 3) First-stage GNN forward pass =====
    with torch.no_grad():
        _sync()
        pred_single_first = gnn_first_period(WW, x)   # presumably (B, N, n_classes)
        _sync()
        start_x_label = from_scores_to_labels_multiclass_batch(
            pred_single_first.detach().cpu().numpy()
        )

    # ===== 4) Refinement chain with the second-stage GNN =====
    with torch.no_grad():
        ref = run_refinement_chain(
            gnn_second_period=gnn_second_period,
            W_np=W_np,
            init_labels=start_x_label,
            true_labels=true_labels,
            args=args,
            device=device,
            total_iters=10,
            verbose=False
        )
        _sync()

    # Default to NaN so a missing "first_iter" entry no longer leaves these
    # names unbound when the metrics dict is built below.
    acc_gnn_second = ari_gnn_second = nmi_gnn_second = float("nan")
    first_iter = ref.get("first_iter")
    if first_iter is not None:
        acc_gnn_second = first_iter["acc"]
        ari_gnn_second = first_iter["ari"]
        nmi_gnn_second = first_iter["nmi"]

    acc_gnn_final = ref["final"]["acc"]
    ari_gnn_final = ref["final"]["ari"]
    nmi_gnn_final = ref["final"]["nmi"]

    # ===== 5) First-stage loss and metrics =====
    # Re-derived from W rather than reusing the earlier array, as in the
    # original code (whether the refinement chain mutates its input is not
    # visible from here).
    W_np = W.detach().cpu().numpy() if isinstance(W, Tensor) else np.asarray(W)

    loss_test_first = compute_loss_multiclass(pred_single_first, true_labels, n_classes)
    _sync()
    acc_gnn_first, gnn_pred_label_first, ari_gnn_first, nmi_gnn_first = gnn_compute_acc_ari_nmi_multiclass(
        pred_single_first, true_labels, n_classes
    )

    # ===== 6) Neighbour-based local refinement of the first-stage labels =====
    gnn_refined = local_refinement_by_neighbors(W_np, gnn_pred_label_first, n_classes)
    acc_gnn_refined, _, ari_gnn_refined, nmi_gnn_refined = compute_acc_ari_nmi(
        gnn_refined, true_labels, n_classes
    )

    # ===== 7) Summary row (total time only) =====
    total_elapsed = time.perf_counter() - t_all0
    loss_value = loss_test_first.item()

    info = ['iter', 'avg loss', 'avg acc', 'edge_density', 'noise', 'model', 'elapsed']
    out  = [iter, loss_value, acc_gnn_first, args.edge_density, args.noise, 'GNN', total_elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    # Drop the references to the large input tensors before returning.
    del WW, x

    metrics = {
        "gnn": {"acc": float(acc_gnn_first), "ari": float(ari_gnn_first), "nmi": float(nmi_gnn_first)},
        "gnn_second": {"acc": float(acc_gnn_second), "ari": float(ari_gnn_second), "nmi": float(nmi_gnn_second)},
        "gnn_final":  {"acc": float(acc_gnn_final),  "ari": float(ari_gnn_final),  "nmi": float(nmi_gnn_final)},
        "gnn_refined": {"acc": float(acc_gnn_refined), "ari": float(ari_gnn_refined), "nmi": float(nmi_gnn_refined)},
    }

    return loss_value, metrics


# Column names of the summary tables built by test_first_period.
# NOTE(review): test_single_first_period only reports the four "gnn*" entries;
# the "spectral_*" columns are therefore filled with NaN unless another code
# path supplies them.
METHODS = [
    "gnn", "gnn_second", "gnn_final", "gnn_refined",
    "spectral_normalized", "spectral_normalized_refined",
    "spectral_unnormalized", "spectral_unnormalized_refined",
    "spectral_adjacency", "spectral_adjacency_refined",
]

def test_first_period(
    gnn_first_period,
    gnn_second_period,
    n_classes,
    gen,
    args,
    iters=None,
    mode='balanced',
    class_sizes=None,
):
    """Run ``iters`` single-graph evaluations and average the metrics.

    Args:
        gnn_first_period: First-stage model.
        gnn_second_period: Second-stage model.
        n_classes: Number of communities.
        gen: Graph generator; ``p_SBM`` and ``q_SBM`` are read for the metadata.
        args: Parsed options; ``num_examples_test``, ``J``, ``N_train`` and
            ``N_test`` are read here.
        iters: Number of graphs to evaluate; defaults to
            ``args.num_examples_test``.
        mode: Passed through to ``test_single_first_period``.
        class_sizes: Passed through to ``test_single_first_period``.

    Returns:
        ``(row_acc, row_ari, row_nmi)``: three dicts sharing the same metadata
        keys plus one entry per name in ``METHODS`` holding the NaN-ignoring
        mean of that metric. A method with no finite samples gets NaN.
    """
    if iters is None:
        iters = args.num_examples_test

    # NOTE(review): train() mode during evaluation mirrors
    # test_single_first_period — confirm this is intentional.
    gnn_first_period.train()
    buckets = {m: {"acc": [], "ari": [], "nmi": []} for m in METHODS}
    metric_names = ("acc", "ari", "nmi")

    for it in range(iters):
        _, metrics = test_single_first_period(
            gnn_first_period=gnn_first_period,
            gnn_second_period=gnn_second_period,
            gen=gen,
            n_classes=n_classes,
            args=args,
            iter=it,
            mode=mode,
            class_sizes=class_sizes,
        )

        for meth in METHODS:
            md = metrics.get(meth, {})
            for name in metric_names:
                v = md.get(name, np.nan)
                # Missing or non-finite values are recorded as NaN so they are
                # skipped by the mean below.
                buckets[meth][name].append(float(v) if np.isfinite(v) else np.nan)

        torch.cuda.empty_cache()

    def mean(vals):
        arr = np.asarray(vals, dtype=float)
        # np.nanmean warns ("Mean of empty slice") on empty or all-NaN input;
        # return the same NaN result without the warning.
        if arr.size == 0 or np.isnan(arr).all():
            return float("nan")
        return float(np.nanmean(arr))

    # Graph size used for the SNR: N_train if set and non-zero, else N_test.
    n = getattr(args, "N_train", getattr(gen, "N_train", None)) or getattr(args, "N_test", getattr(gen, "N_test", -1))
    logn_div_n = np.log(n) / n if n and n > 0 else np.nan
    # Rescale p and q by log(n)/n to obtain the a, b entering the SNR formula.
    a = gen.p_SBM / logn_div_n if np.isfinite(logn_div_n) else np.nan
    b = gen.q_SBM / logn_div_n if np.isfinite(logn_div_n) else np.nan
    k = n_classes
    snr = (a - b) ** 2 / (k * (a + (k - 1) * b)) if np.all(np.isfinite([a, b])) else np.nan

    meta = {
        "n_classes": int(n_classes),
        "class_sizes": str(class_sizes) if class_sizes is not None else "",
        "p_SBM": float(gen.p_SBM),
        "q_SBM": float(gen.q_SBM),
        "J": int(args.J),
        "N_train": int(getattr(args, "N_train", getattr(gen, "N_train", -1))),
        "N_test": int(getattr(args, "N_test", getattr(gen, "N_test", -1))),
        "SNR": float(snr) if np.isfinite(snr) else np.nan,
    }

    row_acc = {**meta}
    row_ari = {**meta}
    row_nmi = {**meta}
    for meth in METHODS:
        row_acc[meth] = mean(buckets[meth]["acc"])
        row_ari[meth] = mean(buckets[meth]["ari"])
        row_nmi[meth] = mean(buckets[meth]["nmi"])

    return row_acc, row_ari, row_nmi

from pathlib import Path

def append_rows_to_excel(row_acc: dict, row_ari: dict, row_nmi: dict, filename="summary.xlsx", extra_info: dict=None):
    """Append one result row to each of the ACC / ARI / NMI sheets of a workbook.

    The workbook is created when ``filename`` does not exist. When it exists,
    the row is appended below the rows already on the sheet, with columns
    merged (existing columns first, new ones after). A sheet that cannot be
    read (``ValueError`` from pandas) is replaced by the new row alone.

    Args:
        row_acc: Row for the "ACC" sheet.
        row_ari: Row for the "ARI" sheet.
        row_nmi: Row for the "NMI" sheet.
        filename: Path of the ``.xlsx`` workbook.
        extra_info: Optional extra columns merged into every row; a
            list/tuple/array under ``"class_sizes"`` is joined with ``"-"``,
            any other value there is converted with ``str``.
    """
    def _clean_extra(info: dict):
        # Copy so the caller's dict is never modified.
        if not info:
            return {}
        cleaned = dict(info)
        if "class_sizes" in cleaned:
            sizes = cleaned["class_sizes"]
            cleaned["class_sizes"] = (
                "-".join(map(str, sizes))
                if isinstance(sizes, (list, tuple, np.ndarray))
                else str(sizes)
            )
        return cleaned

    def _write_sheet(sheet_name: str, row: dict):
        record = {**row, **_clean_extra(extra_info)}
        fresh = pd.DataFrame([record])
        target = Path(filename)

        if not target.exists():
            # First write: create the workbook with this single sheet.
            with pd.ExcelWriter(filename, engine="openpyxl", mode="w") as writer:
                fresh.to_excel(writer, sheet_name=sheet_name, index=False)
            return

        try:
            existing = pd.read_excel(target, sheet_name=sheet_name)
            # Ordered union of column names, existing ones first.
            columns = list(dict.fromkeys([*existing.columns, *fresh.columns]))
            existing = existing.reindex(columns=columns)
            fresh = fresh.reindex(columns=columns)
            combined = pd.concat([existing, fresh], ignore_index=True)
        except ValueError:
            combined = fresh
        with pd.ExcelWriter(filename, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
            combined.to_excel(writer, sheet_name=sheet_name, index=False)

    for sheet, row in (("ACC", row_acc), ("ARI", row_ari), ("NMI", row_nmi)):
        _write_sheet(sheet, row)

def test_first_period_wrapper(args_tuple):
    """Evaluate one (class_sizes, SNR) configuration from a packed argument tuple.

    ``args_tuple`` holds, in order: first-stage model, second-stage model,
    class sizes, SNR, generator, the ``log(N)/N`` scaling factor and
    ``total_ab``. The SBM probabilities are derived from ``solve_ab`` and set
    on a copy of the generator, so the caller's generator is left untouched.
    Remaining options come from the module-level ``args``.

    Returns:
        The ``(row_acc, row_ari, row_nmi)`` tuple produced by
        ``test_first_period``.
    """
    (gnn_first_period, gnn_second_period, class_sizes,
     snr, gen, logN_div_N, total_ab) = args_tuple

    a, b = solve_ab(args.N_train, class_sizes, args.n_classes, snr, total_ab)
    # Scale a, b back to edge probabilities, rounded to four decimals.
    p_SBM, q_SBM = round(a * logN_div_N, 4), round(b * logN_div_N, 4)

    gen_local = gen.copy()
    gen_local.p_SBM, gen_local.q_SBM = p_SBM, q_SBM

    print(f"\n[测试阶段] class_sizes: {class_sizes}, SNR: {snr:.2f}")
    print(f"使用的 SBM 参数: p={p_SBM}, q={q_SBM}")

    rows = test_first_period(
        gnn_first_period=gnn_first_period,
        gnn_second_period=gnn_second_period,
        n_classes=args.n_classes,
        gen=gen_local,
        args=args,
        iters=args.num_examples_test,
        mode=args.mode_isbalanced,
        class_sizes=class_sizes,
    )
    row_acc, row_ari, row_nmi = rows
    return row_acc, row_ari, row_nmi

def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

if __name__ == '__main__':
    # Script entry point: configure the generator from the CLI options and, in
    # "train" mode, train (or reload) the first-stage GNN, freeze it, then
    # train (or reload) the second-stage GNN.
    # NOTE(review): any other --mode value (including the default 'test')
    # only sets up logging, the generator and the model directories.
    try:
        setup_logger("run_gnn")
        gen = Generator()
        # Every generator setting has an identically named CLI option.
        for option in (
            "N_train", "N_test", "N_val",
            "edge_density", "p_SBM", "q_SBM",
            "random_noise", "noise", "noise_model",
            "generative_model", "n_classes",
            "num_examples_train", "num_examples_test", "num_examples_val",
        ):
            setattr(gen, option, getattr(args, option))

        root_model_dir = "model_GNN"
        os.makedirs(root_model_dir, exist_ok=True)
        folder_name = f"GNN_model_first_classes{args.n_classes}"
        full_save_dir = os.path.join(root_model_dir, folder_name)
        os.makedirs(full_save_dir, exist_ok=True)

        filename_first = f'gnn_J{args.J}_lyr{args.num_layers}_classes{args.n_classes}_numfeatures{args.num_features}'
        path_first = os.path.join(full_save_dir, filename_first)

        filename_second = f'local+refin_gnn_J{args.J_second}_lyr{args.num_layers_second}_classes{args.n_classes}_numfeatures{args.num_features_second}'
        path_second = os.path.join(full_save_dir, filename_second)

        if args.mode == "train":
            # Fail fast with a clear message: only this model family is
            # constructed below, and any other value previously surfaced as a
            # NameError after the data had already been prepared.
            if args.generative_model != 'SBM_multiclass':
                raise ValueError(
                    f"Unsupported generative_model {args.generative_model!r}; "
                    "only 'SBM_multiclass' is implemented for training."
                )

            gen.prepare_data()

            num_workers = min(4, multiprocessing.cpu_count() - 1)
            print("num_workers", num_workers)

            train_dataset = SBMDataset(gen.data_train)
            val_dataset = SBMDataset(gen.data_val)
            test_dataset = SBMDataset(gen.data_test)

            train_loader = DataLoader(
                train_dataset, batch_size=args.batch_size, shuffle=True,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            val_loader = DataLoader(
                val_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )
            # Built for parity with the other splits; not consumed below.
            test_loader = DataLoader(
                test_dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=simple_collate_fn
            )

            # ---------------- Stage 1 ----------------
            print(f"[Phase 1] Train the GNN：n_classes={args.n_classes}, layers={args.num_layers}, J={args.J}, num_features={args.num_features}")
            gnn_first_period = GNN_multiclass(
                args.num_features, args.num_layers, args.J + 3, n_classes=args.n_classes
            ).to(device)
            if os.path.exists(path_first):
                print(f"[Stage 1] Detect that there is already an 'optimal weight' and load it directly: {path_first}")
            else:
                train_first_period_with_early_stopping(
                    gnn_first_period, train_loader, val_loader, args.n_classes, args,
                    epochs=100, patience=5, save_path=path_first, filename=filename_first
                )
                print(f"[Phase 1] Backload from optimal weight to memory: {path_first}")
            # In both cases the weights stored at path_first are what is used.
            gnn_first_period = load_best_model_into(gnn_first_period, path_first, device)

            # The first stage is only used as a fixed feature provider from here on.
            maybe_freeze(gnn_first_period, freeze=True)

            # ---------------- Stage 2 ----------------
            print(f"[Stage 2] Training GNN：n_classes={args.n_classes}, layers={args.num_layers_second}, J={args.J_second}, num_features={args.num_features_second}")
            if os.path.exists(path_second):
                print(f"[Stage 2] Detect the existing model file and load it directly: {path_second}")
                # NOTE(review): unlike stage 1 this expects the file to contain
                # a whole pickled module rather than a state dict — confirm
                # against what train_second_period_with_early_stopping saves.
                gnn_second_period = torch.load(path_second, map_location=device)
                gnn_second_period = gnn_second_period.to(device)
            else:
                gnn_second_period = GNN_multiclass_second_period(
                    args.num_features_second, args.num_layers_second, args.J_second + 3, n_classes=args.n_classes
                ).to(device)

                train_second_period_with_early_stopping(
                    gnn_first_period, gnn_second_period, train_loader, val_loader,
                    args.n_classes, args,
                    epochs=100, patience=5, save_path=path_second, filename=filename_second
                )
                print(f"[Phase 2] Save the model to {path_second}")

    except Exception:
        # Top-level boundary: record the failure in the log (stderr is
        # redirected there by setup_logger) instead of losing it.
        import traceback
        traceback.print_exc()
