import numpy as np
import torch
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, recall_score

def _binary_metrics(y_true, y_prob, threshold=0.5):
    """Compute binary-classification metrics for one set of predictions.

    Args:
        y_true: Array-like of 0/1 ground-truth labels (cast to int and flattened).
        y_prob: Array-like of predicted positive-class probabilities, aligned
            element-wise with ``y_true``.
        threshold: Probabilities ``>= threshold`` are predicted as class 1.

    Returns:
        Dict with keys ``acc_neg`` / ``acc_pos`` (per-class recall for labels
        0 and 1), ``auc``, ``acc``, ``auprc``, ``rec_at_k`` and ``K`` (number
        of positives), or ``None`` when there are no samples to score.
    """
    y_true = np.asarray(y_true).astype(int).reshape(-1)
    y_prob = np.asarray(y_prob).reshape(-1)
    if y_true.size == 0:
        return None
    y_pred = (y_prob >= threshold).astype(int)

    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        # roc_auc_score raises when y_true contains a single class; 0.0 is a
        # sentinel for "undefined", not a real AUC value.
        auc = 0.0
    acc = accuracy_score(y_true, y_pred)
    auprc = average_precision_score(y_true, y_prob)

    # Recall@K with K = number of positives: the fraction of true positives
    # that land in the K highest-scored samples.
    K = int(np.sum(y_true))
    if K > 0:
        top_k = np.argsort(y_prob)[::-1][:K]
        rec_at_k = np.sum(y_true[top_k] == 1) / K
    else:
        rec_at_k = 0.0

    acc_neg, acc_pos = recall_score(y_true, y_pred, average=None, labels=[0, 1])
    return {
        "acc_neg": float(acc_neg),
        "acc_pos": float(acc_pos),
        "auc": float(auc),
        "acc": float(acc),
        "auprc": float(auprc),
        "rec_at_k": float(rec_at_k),
        "K": K,
    }


def _print_metrics(tag, m):
    """Print one metrics dict (as returned by ``_binary_metrics``) on a single line."""
    print(f"[{tag}] neg(0)召回: {m['acc_neg']:.4f}  pos(1)召回: {m['acc_pos']:.4f} | "
          f"AUC: {m['auc']:.4f}  ACC: {m['acc']:.4f}  AUPRC: {m['auprc']:.4f}  "
          f"rec@K: {m['rec_at_k']:.4f}  K={m['K']}")


def test_all(base_model, booster2, stage3, g_o, g_sim, leaf_nb_overlap_tensor, test_loader, device,
             eta1=0.5, eta2=0.5, log1p=True, *, report_stages=False, threshold=0.5):
    """Evaluate the three-stage stacked model on the test split and print metrics.

    The final logit is ``base + eta1 * booster2 + eta2 * stage3``. All three
    models are run once over every node; per-node probabilities are then
    gathered at the indices yielded by ``test_loader``.

    Args:
        base_model: Model called as ``base_model(g_o, binned)``; the first
            element of its returned tuple is used as logits (presumably
            ``[N, 1]``, since column 0 is taken after the sigmoid).
        booster2: Model called as ``booster2(g_sim, binned)``, same output
            convention as ``base_model``.
        stage3: Model called as ``stage3(feat_nb)``, returning logits that are
            added to the stage-1/2 logits.
        g_o: Graph whose ``ndata`` holds ``'binned_feature'`` and ``'label'``.
        g_sim: Graph passed to ``booster2``; its optional ``edata['w']`` is
            moved to ``device``.
        leaf_nb_overlap_tensor: Node features for ``stage3``.
        test_loader: Iterable of ``(batch_indices, _)`` pairs. Only the indices
            are used.
        device: Target torch device.
        eta1: Weight of the ``booster2`` logits.
        eta2: Weight of the ``stage3`` logits.
        log1p: If True, apply ``torch.log1p`` to the stage-3 features.
        report_stages: If True, also report "Base" and "Base+S2" metrics
            before the final model.
        threshold: Decision threshold on probabilities for ACC / per-class recall.

    Returns:
        Dict mapping each reported tag to its metrics dict, or to ``None``
        when the loader yielded no samples.
    """
    g_o = g_o.to(device)
    g_sim = g_sim.to(device)
    if 'w' in g_sim.edata:
        g_sim.edata['w'] = g_sim.edata['w'].to(device)

    binned = g_o.ndata['binned_feature'].to(device)
    # Labels are only consumed on the host by sklearn, so transfer them once
    # instead of indexing a device tensor and syncing once per batch.
    labels_np = g_o.ndata['label'].detach().cpu().numpy()

    base_model.eval()
    booster2.eval()
    stage3.eval()

    feat_nb = leaf_nb_overlap_tensor.to(device)
    # Count-like features may arrive as an integer tensor; stage3 needs floats.
    if not torch.is_floating_point(feat_nb):
        feat_nb = feat_nb.float()
    if log1p:
        feat_nb = torch.log1p(feat_nb)

    with torch.no_grad():
        base_logits_all, _ = base_model(g_o, binned)
        boost2_logits_all, _ = booster2(g_sim, binned)
        base12_logits_all = base_logits_all + eta1 * boost2_logits_all

        logits3_all = stage3(feat_nb)
        final_logits_all = base12_logits_all + eta2 * logits3_all

        probs_base = torch.sigmoid(base_logits_all)[:, 0].cpu().numpy()
        probs_base12 = torch.sigmoid(base12_logits_all)[:, 0].cpu().numpy()
        probs_final = torch.sigmoid(final_logits_all)[:, 0].cpu().numpy()

    # Collect every test index first, then gather labels/probabilities in one
    # vectorized indexing step. Loader order is preserved.
    batches = [torch.as_tensor(batch_indices).cpu().numpy().reshape(-1)
               for batch_indices, _ in test_loader]
    idx = np.concatenate(batches) if batches else np.empty(0, dtype=np.int64)
    y_true = labels_np[idx]

    stages = [("Base+S2+S3 (Final)", probs_final)]
    if report_stages:
        stages = [("Base", probs_base), ("Base+S2", probs_base12)] + stages

    print("==== Test Metrics ====")
    results = {}
    for tag, probs in stages:
        metrics = _binary_metrics(y_true, probs[idx], threshold)
        results[tag] = metrics
        if metrics is None:
            print(f"[{tag}] no test samples")
        else:
            _print_metrics(tag, metrics)
    return results
