import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, recall_score

def _binary_val_metrics(y_val, p_val):
    """Compute binary-classification validation metrics at a 0.5 threshold.

    Args:
        y_val: 1-D numpy array of ground-truth labels (assumed 0/1).
        p_val: 1-D numpy array of positive-class probabilities aligned with ``y_val``.

    Returns:
        Tuple ``(auc, acc, auprc, rec_at_k, neg_recall, pos_recall)``.
    """
    try:
        auc = roc_auc_score(y_val, p_val)
    except ValueError:
        # roc_auc_score raises when y_val holds a single class; report 0.0 instead.
        auc = 0.0
    auprc = average_precision_score(y_val, p_val)
    y_pred = (p_val >= 0.5).astype(int)
    acc = accuracy_score(y_val, y_pred)

    # recall@K with K = number of positives: fraction of positives ranked in the top-K scores.
    k = int(np.sum(y_val))
    if k > 0:
        top_k = np.argsort(p_val)[::-1][:k]
        rec_at_k = np.sum(y_val[top_k] == 1) / k
    else:
        rec_at_k = 0.0

    # Per-class recall: index 0 -> negatives, index 1 -> positives.
    per_class = recall_score(y_val, y_pred, average=None, labels=[0, 1])
    return auc, acc, auprc, rec_at_k, per_class[0], per_class[1]


def train_base(model, g, train_loader, val_loader, epochs, device, lr=0.01):
    """Train ``model`` full-batch on graph ``g`` and restore the best-validation-accuracy weights.

    Every epoch runs one forward/backward pass over the whole graph. The loss
    uses only the train-subset nodes. The model is then evaluated on the
    validation-subset nodes and the metrics are printed.

    Args:
        model: Module called as ``model(g, binned)`` and returning ``(logits, _)``.
            Presumably ``logits`` is (num_nodes, 1); only column 0 is scored. TODO confirm.
        g: Graph object exposing ``.to(device)`` and
            ``ndata['binned_feature']`` / ``ndata['label']``.
        train_loader: Loader whose ``.dataset.indices`` selects training nodes
            (list or tensor). It is not iterated.
        val_loader: Loader whose ``.dataset.indices`` selects validation nodes
            (list or tensor). It is not iterated.
        epochs: Number of full-graph optimisation steps.
        device: Target device for the graph, features and labels.
        lr: Adam learning rate (default 0.01).

    Returns:
        ``model`` with the weights from the epoch with the highest validation accuracy loaded.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    g = g.to(device)
    binned = g.ndata['binned_feature'].to(device)
    labels = g.ndata['label'].to(device)

    # Subset.indices may be a Python list or a tensor; as_tensor handles both.
    tr_idx = torch.as_tensor(train_loader.dataset.indices, device=device)
    val_idx = torch.as_tensor(val_loader.dataset.indices).cpu().numpy()

    # Labels never change, so slice the train targets and validation truth once.
    y_train = labels[tr_idx].unsqueeze(1)
    y_val = labels.cpu().numpy()[val_idx]

    best_acc = float('-inf')
    best_state = None

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        logits, _ = model(g, binned)
        # BCEWithLogitsLoss requires a floating target with the same dtype as the input.
        loss = criterion(logits[tr_idx], y_train.to(logits.dtype))
        loss.backward()
        optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            logits, _ = model(g, binned)
            probs = torch.sigmoid(logits)[:, 0].cpu().numpy()
        p_val = probs[val_idx]

        auc, acc, auprc, rec_at_k, acc_neg, acc_pos = _binary_val_metrics(y_val, p_val)
        print(f"[BASE] Epoch {epoch+1}/{epochs} - Loss {loss.item():.4f} | Val AUC {auc:.4f} ACC {acc:.4f} "
              f"AUPRC {auprc:.4f} rec@K {rec_at_k:.4f} | neg_recall {acc_neg:.4f} pos_recall {acc_pos:.4f}")

        if acc > best_acc:
            best_acc = acc
            # state_dict() returns references to the live parameter tensors. Without a deep
            # copy, later optimizer steps would overwrite this "best" snapshot in place.
            best_state = copy.deepcopy(model.state_dict())

    if best_state is not None:
        model.load_state_dict(best_state)
    return model
