import time

import torch
import torch.nn as nn
from tqdm import tqdm

from load import get_gnn_inputs
from losses import compute_loss_multiclass, compute_accuracy_multiclass, compute_nmi_multiclass

# Column layout of the progress table printed after every training batch:
# iteration index, batch loss, batch accuracy, elapsed seconds.
template_header = '{:<6} {:<10} {:<10} {:<10}'
template_row = '{:<6d} {:<10.4f} {:<10.2f} {:<10.2f}'

# Use CUDA when it is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Module-level caches. They are not referenced anywhere else in this file
# chunk -- presumably filled/read by other modules; verify before removing.
cached_graphs = []
cached_labels = []


# Single-batch training step for the first-period GNN.
def train_batch_first_period(gnn, optimizer, batch, n_classes, iter, device, args):
    """Run one optimisation step on one batch and print a progress line.

    Uses the batched, permutation-aware ``compute_loss_multiclass`` and
    ``compute_accuracy_multiclass`` helpers, which handle the batch dimension
    themselves.

    Args:
        gnn: Model invoked as ``gnn(WW, x)``; switched to training mode here.
        optimizer: Optimizer over ``gnn``'s parameters.
        batch: Mapping with ``'adj'`` and ``'labels'`` tensors (presumably
            shaped (B, N, N) and (B, N) -- confirm against the data loader).
        n_classes: Number of classes passed to the loss/accuracy helpers.
        iter: Iteration index; only printed. It shadows the builtin but is
            kept because callers pass it by keyword.
        device: Device the model lives on; GNN inputs and labels are moved here.
        args: Namespace providing ``J`` (forwarded to ``get_gnn_inputs``) and
            ``clip_grad_norm`` (maximum total gradient norm).

    Returns:
        Tuple ``(loss_value, acc)``: the batch loss as a Python float and the
        first value returned by ``compute_accuracy_multiclass``.
    """
    gnn.train()
    # get_gnn_inputs consumes NumPy arrays, so take the adjacency straight from
    # the batch on the CPU instead of copying it to `device` and back again.
    adjacency = batch['adj'].cpu().numpy()
    labels = batch['labels'].to(device)

    start = time.perf_counter()

    # Batched GNN inputs; the original note says WW is (B, N, N, J+3) and
    # x is (B, N, d) -- confirm in load.get_gnn_inputs.
    WW, x = get_gnn_inputs(adjacency, args.J)
    WW = WW.clone().detach().to(device=device, dtype=torch.float32)
    x = x.clone().detach().to(device=device, dtype=torch.float32)

    optimizer.zero_grad(set_to_none=True)

    pred = gnn(WW, x)
    # The permutation-aware loss handles the batch dimension; no reshape needed.
    loss = compute_loss_multiclass(pred, labels, n_classes)
    loss.backward()

    # Clip gradients, then update. clip_grad_norm_ returns the pre-clipping
    # total norm, which can be printed if gradient magnitudes need inspecting.
    nn.utils.clip_grad_norm_(gnn.parameters(), args.clip_grad_norm)
    optimizer.step()

    # Accuracy is only a metric: detach so no autograd graph is built for it.
    # Note `pred` came from the weights *before* optimizer.step().
    acc, _ = compute_accuracy_multiclass(pred.detach(), labels, n_classes)

    elapsed_time = time.perf_counter() - start
    loss_value = loss.item()

    print(template_header.format(*['iter', 'avg loss', 'avg acc', 'elapsed']))
    print(template_row.format(iter, loss_value, acc, elapsed_time))

    return loss_value, acc


def evaluate_on_loader(gnn, val_loader, n_classes, args, device):
    """Return the mean per-batch validation loss and NMI of ``gnn``.

    The model is evaluated in eval mode (dropout off, batch-norm using its
    running statistics) under ``torch.no_grad()``. Its previous train/eval
    mode is restored afterwards, even if evaluation raises.

    Args:
        gnn: Model invoked as ``gnn(WW, x)``.
        val_loader: Iterable of batches, each a mapping with ``'adj'`` and
            ``'labels'`` tensors.
        n_classes: Number of classes passed to ``compute_loss_multiclass``.
        args: Namespace providing ``J`` (forwarded to ``get_gnn_inputs``).
        device: Device the model lives on; inputs and labels are moved here.

    Returns:
        ``(avg_loss, avg_nmi)``: unweighted means over the batches actually
        yielded by ``val_loader``.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    # Evaluation must not run in training mode: dropout would add noise and
    # batch-norm would both use and update its statistics on validation data.
    was_training = gnn.training
    gnn.eval()

    total_loss, total_nmi, n_batches = 0.0, 0.0, 0
    try:
        with torch.no_grad():
            for batch in val_loader:
                # get_gnn_inputs consumes NumPy, so skip the device round trip.
                adjacency = batch['adj'].cpu().numpy()
                labels = batch['labels'].to(device)

                WW, x = get_gnn_inputs(adjacency, args.J)
                WW = WW.clone().detach().to(device=device, dtype=torch.float32)
                x = x.clone().detach().to(device=device, dtype=torch.float32)

                pred = gnn(WW, x)
                loss = compute_loss_multiclass(pred, labels, n_classes)
                nmi_mean, _ = compute_nmi_multiclass(pred, labels)

                total_loss += loss.item()
                total_nmi += nmi_mean
                n_batches += 1
    finally:
        gnn.train(was_training)

    if n_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute validation metrics")

    avg_loss = total_loss / n_batches
    avg_nmi = total_nmi / n_batches
    return avg_loss, avg_nmi

def _save_model_on_cpu(model, path, target_device):
    """Pickle the whole ``model`` to ``path`` from the CPU, then move it back to ``target_device``."""
    try:
        torch.save(model.cpu(), path)
    finally:
        # Move back even if saving fails, so training state stays consistent.
        model = model.to(target_device)
    return model


def train_first_period_with_early_stopping(
    gnn,
    train_loader,
    val_loader,
    n_classes,
    args,
    epochs: int = 100,
    patience: int = 6,
    save_path: str = 'best_model.pt',
    filename: str = "filename_first",
    acc_eps: float = 1e-8,
    loss_eps: float = 1e-12,
):
    """Train the first-period GNN with Adamax and early stopping on validation NMI.

    Model selection: a new epoch is "better" when its validation NMI exceeds
    the best so far by more than ``acc_eps``. When the NMIs are within
    ``acc_eps`` of each other, a validation loss lower by more than
    ``loss_eps`` wins instead. Training stops after ``patience`` consecutive
    epochs without improvement.

    Args:
        gnn: Model to train; expected to already live on the module-level ``device``.
        train_loader: Iterable of training batches for ``train_batch_first_period``.
        val_loader: Iterable of validation batches for ``evaluate_on_loader``.
        n_classes: Number of classes.
        args: Namespace providing ``lr`` plus whatever the batch/eval helpers
            read (``J``, ``clip_grad_norm``).
        epochs: Maximum number of epochs.
        patience: Epochs without improvement before stopping.
        save_path: Where the best model (whole pickled module) is saved.
        filename: Where a snapshot of the latest epoch's model is saved.
        acc_eps: Tolerance on validation NMI (name kept for backward compatibility).
        loss_eps: Tolerance on validation loss used to break NMI ties.

    Returns:
        ``(loss_lst, acc_lst)``: per-iteration training losses and accuracies
        across all epochs that ran.
    """
    gnn.train()
    optimizer = torch.optim.Adamax(gnn.parameters(), lr=args.lr)

    loss_lst, acc_lst = [], []
    best_val_nmi = -1.0
    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        gnn.train()

        for iter_idx, batch in enumerate(tqdm(train_loader)):
            loss, acc = train_batch_first_period(
                gnn=gnn,
                optimizer=optimizer,
                batch=batch,
                n_classes=n_classes,
                iter=iter_idx,
                device=device,
                args=args
            )
            loss_lst.append(loss)
            acc_lst.append(acc)

            torch.cuda.empty_cache()

        # Validation pass.
        val_loss, val_nmi = evaluate_on_loader(
            gnn, val_loader, n_classes, args, device=device
        )
        print(f"Validation Loss: {val_loss:.6f}, NMI: {val_nmi:.6f}")

        # Snapshot of the latest epoch, overwritten every epoch.
        gnn = _save_model_on_cpu(gnn, filename, device)

        # Compare NMI first; only on an NMI tie does the loss decide.
        if val_nmi > best_val_nmi + acc_eps:
            reason = "val_nmi improved"
        elif abs(val_nmi - best_val_nmi) <= acc_eps and val_loss < best_val_loss - loss_eps:
            reason = "val_nmi tie, val_loss improved"
        else:
            reason = None

        if reason is not None:
            best_val_nmi = val_nmi
            best_val_loss = val_loss
            patience_counter = 0

            # Saves the whole module (pickle); torch.save(gnn.state_dict(), ...)
            # would be more portable but changes the on-disk format callers load.
            gnn = _save_model_on_cpu(gnn, save_path, device)

            print(f"New best model saved ({reason}). best_nmi={best_val_nmi:.6f}, best_loss={best_val_loss:.6f}")
        else:
            patience_counter += 1
            print(f"No improvement ({patience_counter}/{patience}). "
                  f"best_nmi={best_val_nmi:.6f}, best_loss={best_val_loss:.6f}")

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

        torch.cuda.empty_cache()  # Optional: release cached GPU memory once per epoch.

    return loss_lst, acc_lst
