import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, recall_score

def compute_boost_weights(base_probs, labels, mode='miscls', weight_pos=3.0, weight_neg=1.5, gamma=2.0):
    """Compute per-sample loss weights from the base model's predictions.

    Args:
        base_probs: Array-like of base-model probabilities for the positive
            class; thresholded at 0.5 to obtain hard predictions.
        labels: Array-like of binary labels (cast to int), same shape as
            ``base_probs``.
        mode: ``'miscls'`` adds extra weight to misclassified samples,
            ``'focal'`` scales weights by ``|label - prob| ** gamma``,
            ``'both'`` applies the misclassification bonus and then the focal
            scaling. Any other value leaves every weight at 1.
        weight_pos: Extra weight added to false negatives (label 1, predicted 0).
        weight_neg: Extra weight added to false positives (label 0, predicted 1).
        gamma: Exponent of the focal term.

    Returns:
        ``np.float32`` array of weights with the same shape as ``base_probs``.
    """
    # Accept lists/tuples as well as arrays.
    y = np.asarray(labels).astype(int)
    p = np.asarray(base_probs)
    pred = (p >= 0.5).astype(int)

    w = np.ones_like(p, dtype=np.float32)
    if mode in ('miscls', 'both'):
        miss_pos = (y == 1) & (pred == 0)   # missed positives (false negatives)
        miss_neg = (y == 0) & (pred == 1)   # false alarms (false positives)
        w = w + weight_pos * miss_pos.astype(np.float32) + weight_neg * miss_neg.astype(np.float32)
    if mode in ('focal', 'both'):
        err = np.abs(y - p)
        # The small epsilon keeps perfectly predicted samples from getting an
        # exactly-zero weight.
        w = w * (np.power(err, gamma) + 1e-6)
    # The focal branch can promote to float64; keep the output dtype the same
    # in every mode.
    return w.astype(np.float32, copy=False)

def train_booster(booster, base_model, g_o, g_sim,
                  train_loader, val_loader, epochs, device,
                  eta=0.5, boost_mode='miscls', boost_weight_pos=3.0, boost_weight_neg=1.5, boost_gamma=2.0,
                  lr=0.01):
    """Train ``booster`` as an additive correction on top of a frozen ``base_model``.

    The combined prediction is ``base_logits + eta * boost_logits``. Each epoch
    is one full-graph forward/backward pass; the loss is a weighted binary
    cross-entropy over the training nodes, with per-sample weights obtained
    from ``compute_boost_weights`` applied to the base model's probabilities.
    After every epoch the combined model is evaluated on the validation nodes
    and the booster weights with the highest validation accuracy are kept.

    Side effects: ``base_model`` is put in eval mode and all its parameters get
    ``requires_grad = False``; ``g_sim.edata['w']`` (if present) is reassigned
    after moving it to ``device``; metrics are printed every epoch.

    Args:
        booster: Module called as ``booster(g_sim, binned)``; the first element
            of its return value is used as the boost logits.
        base_model: Module called as ``base_model(g_o, binned)``; the first
            element of its return value is used as the base logits.
        g_o: Graph providing ``ndata['binned_feature']`` and ``ndata['label']``.
        g_sim: Graph the booster operates on.
        train_loader: Only ``train_loader.dataset.indices`` is read (training
            node indices).
        val_loader: Only ``val_loader.dataset.indices`` is read (validation
            node indices).
        epochs: Number of training epochs.
        device: Device the graphs and tensors are moved to.
        eta: Scale applied to the boost logits before adding the base logits.
        boost_mode, boost_weight_pos, boost_weight_neg, boost_gamma: Forwarded
            to ``compute_boost_weights`` as ``mode``, ``weight_pos``,
            ``weight_neg`` and ``gamma``.
        lr: Learning rate of the Adam optimizer.

    Returns:
        Tuple ``(booster, base_logits)``: the booster with the best-validation-
        accuracy weights loaded (left at its final state if no epoch reached an
        accuracy above 0), and the detached base-model logits for all nodes.
    """
    # Freeze the base model: only the booster is trained.
    for param in base_model.parameters():
        param.requires_grad = False
    base_model.eval()

    criterion = nn.BCEWithLogitsLoss(reduction='none')
    optimizer = optim.Adam(booster.parameters(), lr=lr)

    g_o = g_o.to(device)
    g_sim = g_sim.to(device)
    if 'w' in g_sim.edata:
        g_sim.edata['w'] = g_sim.edata['w'].to(device)

    binned = g_o.ndata['binned_feature'].to(device)
    labels_t = g_o.ndata['label'].to(device)

    # The base model is frozen, so its outputs are computed once.
    with torch.no_grad():
        base_logits_all, _ = base_model(g_o, binned)       # [N,1]
        base_probs_all = torch.sigmoid(base_logits_all)[:, 0].cpu().numpy()
    base_logits_all = base_logits_all.detach()

    # Everything below is epoch-invariant, so it is prepared once up front.
    # ``as_tensor`` lets ``indices`` be a plain list as well as a tensor.
    tr_idx = torch.as_tensor(train_loader.dataset.indices)
    tr_idx_np = tr_idx.cpu().numpy()
    val_idx_np = torch.as_tensor(val_loader.dataset.indices).cpu().numpy()

    train_labels = labels_t[tr_idx].unsqueeze(1)
    y_val = labels_t[val_idx_np].cpu().numpy()

    weights_np = compute_boost_weights(
        base_probs=base_probs_all[tr_idx_np],
        labels=labels_t[tr_idx].cpu().numpy(),
        mode=boost_mode, weight_pos=boost_weight_pos, weight_neg=boost_weight_neg, gamma=boost_gamma
    )
    weights = torch.tensor(weights_np, dtype=torch.float32, device=device).unsqueeze(1)

    best_acc = 0.0
    best_state = None

    for epoch in range(epochs):
        booster.train()
        optimizer.zero_grad()

        boost_logits_all, _ = booster(g_sim, binned)       # [N,1]
        final_logits_all = base_logits_all + eta * boost_logits_all

        train_logits = final_logits_all[tr_idx]
        # BCEWithLogitsLoss needs a floating-point target of the logits' dtype.
        loss_vec = criterion(train_logits, train_labels.to(train_logits.dtype))
        loss = (weights * loss_vec).mean()
        loss.backward()
        optimizer.step()

        # Validation
        booster.eval()
        with torch.no_grad():
            boost_logits_all, _ = booster(g_sim, binned)
            final_logits_all = base_logits_all + eta * boost_logits_all
            final_probs_all = torch.sigmoid(final_logits_all)[:, 0].cpu().numpy()

        p_val = final_probs_all[val_idx_np]
        pred_val = (p_val >= 0.5).astype(int)

        try:
            auc = roc_auc_score(y_val, p_val)
        except ValueError:
            # Raised e.g. when the validation labels contain a single class.
            auc = 0.0
        auprc = average_precision_score(y_val, p_val)
        acc = accuracy_score(y_val, pred_val)

        # Recall@K with K = number of positives in the validation set.
        K = int(np.sum(y_val))
        sorted_indices = np.argsort(p_val)[::-1]
        top_k_labels = y_val[sorted_indices[:K]] if K > 0 else np.array([])
        rec_at_k = (np.sum(top_k_labels == 1) / K) if K > 0 else 0.0

        per_class_acc = recall_score(y_val, pred_val, average=None, labels=[0, 1])
        acc_neg, acc_pos = per_class_acc[0], per_class_acc[1]
        print(f"[BOOST-2] Epoch {epoch+1}/{epochs} - Loss {loss.item():.4f} | Val AUC {auc:.4f} ACC {acc:.4f} "
              f"AUPRC {auprc:.4f} rec@K {rec_at_k:.4f} | neg_recall {acc_neg:.4f} pos_recall {acc_pos:.4f}")

        if acc > best_acc:
            best_acc = acc
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; clone to get a real snapshot.
            best_state = {name: tensor.detach().clone()
                          for name, tensor in booster.state_dict().items()}

    if best_state is not None:
        booster.load_state_dict(best_state)

    return booster, base_logits_all

def logits_after_stage12(base_model, booster, g_o, g_sim, device, eta1=0.5):
    """Return the combined Stage-1 + Stage-2 logits, computed without gradients.

    Both graphs are moved to ``device`` (``g_sim.edata['w']`` too, when it
    exists), then ``base_model(g_o, binned)`` and ``booster(g_sim, binned)``
    are evaluated on ``g_o.ndata['binned_feature']`` and combined as
    ``base_logits + eta1 * boost_logits``.

    NOTE(review): neither model is switched to eval mode here; presumably the
    caller has already done so - confirm.

    Returns:
        Detached tensor of combined logits (the original code annotates it
        as [N,1]).
    """
    graph_orig = g_o.to(device)
    graph_sim = g_sim.to(device)
    if 'w' in graph_sim.edata:
        graph_sim.edata['w'] = graph_sim.edata['w'].to(device)

    features = graph_orig.ndata['binned_feature'].to(device)

    with torch.no_grad():
        stage1_logits, _ = base_model(graph_orig, features)
        stage2_logits, _ = booster(graph_sim, features)
        combined = stage1_logits + eta1 * stage2_logits

    return combined.detach()  # [N,1]
