import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, recall_score

def compute_stage3_weights(p_base12, labels, mode='fn', weight_fn=5.0, weight_fp=1.0, gamma=0.0):
    """Build per-sample loss weights that emphasise the base model's mistakes.

    Every sample starts at weight 1. Base predictions are thresholded at 0.5
    and compared against the labels:

    * mode 'fn'   -- missed anomalies (label 1, predicted 0) get ``+weight_fn``.
    * mode 'both' -- as 'fn', and false alarms (label 0, predicted 1) also get
      ``+weight_fp``.
    * any other mode leaves the weights uniform.

    When ``gamma > 0`` each weight is additionally multiplied by
    ``|label - p| ** gamma + 1e-6``.

    Args:
        p_base12: numpy array [N] of base-model probabilities.
        labels: numpy array [N] with values in {0, 1}.
        mode: which error types to boost ('fn' or 'both').
        weight_fn: additive boost for false negatives.
        weight_fp: additive boost for false positives (mode 'both' only).
        gamma: exponent of the optional error-magnitude factor.

    Returns:
        numpy array [N] of sample weights.
    """
    targets = labels.astype(int)
    predicted_pos = p_base12 >= 0.5

    sample_w = np.ones_like(p_base12, dtype=np.float32)

    if mode in ('fn', 'both'):
        # Anomalies the base model failed to flag.
        missed = (targets == 1) & ~predicted_pos
        sample_w = sample_w + weight_fn * missed.astype(np.float32)

    if mode == 'both':
        # Normal samples the base model flagged as anomalous.
        false_alarm = (targets == 0) & predicted_pos
        sample_w = sample_w + weight_fp * false_alarm.astype(np.float32)

    if gamma > 0:
        abs_err = np.abs(targets - p_base12)
        # The small constant keeps zero-error samples from getting exactly 0 weight.
        sample_w = sample_w * (np.power(abs_err, gamma) + 1e-6)

    return sample_w

def train_stage3(stage3, g_o, leaf_nb_overlap_tensor, train_loader, val_loader, epochs, device,
                 base12_logits_fixed, eta2=0.5, focus_mode='fn', weight_fn=5.0, weight_fp=1.0, gamma=0.0, log1p=True):
    """Train only Stage-3 as a residual correction on top of fixed base logits.

    The final prediction is
    ``logits_final = base12_logits_fixed + eta2 * logits_stage3``; only the
    parameters of ``stage3`` are optimised (Adam, lr=0.01, full-batch). The
    training loss is a per-sample weighted BCE-with-logits, with weights from
    ``compute_stage3_weights`` applied to the base-model probabilities.

    After each epoch, validation metrics are printed and the weights with the
    highest validation accuracy are remembered; they are restored before
    returning.

    Args:
        stage3: module called as ``stage3(features)``; its output is added to
            the base logits (expected shape [N, 1] to match them).
        g_o: graph object providing ``.to(device)`` and ``.ndata['label']``
            (presumably a DGL graph -- confirm against the caller).
        leaf_nb_overlap_tensor: FloatTensor [N, T] of input features.
        train_loader: loader whose ``.dataset.indices`` is a tensor of
            training node indices.
        val_loader: loader whose ``.dataset.indices`` is a tensor of
            validation node indices.
        epochs: number of full-batch training epochs.
        device: torch device used for the features, labels and base logits.
        base12_logits_fixed: base-model logits, indexed as ``[:, 0]`` (so
            expected to be [N, 1]); detached and never updated.
        eta2: scale applied to the Stage-3 logits.
        focus_mode, weight_fn, weight_fp, gamma: forwarded to
            ``compute_stage3_weights``.
        log1p: if True, apply ``log1p`` to the features to compress long tails.

    Returns:
        ``stage3`` with the best-validation-accuracy weights loaded (left at
        its final-epoch weights if validation accuracy never exceeded 0).
    """
    import copy  # local import: needed only for the best-checkpoint snapshot

    criterion = nn.BCEWithLogitsLoss(reduction='none')
    optimizer = optim.Adam(stage3.parameters(), lr=0.01)

    g_o = g_o.to(device)
    labels_t = g_o.ndata['label'].to(device)

    feat_nb = leaf_nb_overlap_tensor.clone()
    if log1p:
        feat_nb = torch.log1p(feat_nb)  # compress the long tail
    feat_nb = feat_nb.to(device)  # [N, T]

    base12_logits_fixed = base12_logits_fixed.to(device).detach()
    base12_probs = torch.sigmoid(base12_logits_fixed)[:, 0].cpu().numpy()

    # ---- Epoch-invariant training data -------------------------------------
    # The base probabilities and labels never change, so the sample weights
    # are computed once instead of once per epoch.
    tr_idx = train_loader.dataset.indices
    train_labels_flat = labels_t[tr_idx]
    weights_np = compute_stage3_weights(
        p_base12=base12_probs[tr_idx.cpu().numpy()],
        labels=train_labels_flat.cpu().numpy(),
        mode=focus_mode, weight_fn=weight_fn, weight_fp=weight_fp, gamma=gamma
    )
    weights = torch.tensor(weights_np, dtype=torch.float32, device=device).unsqueeze(1)
    train_labels = train_labels_flat.unsqueeze(1)

    # ---- Epoch-invariant validation data -----------------------------------
    val_idx_np = val_loader.dataset.indices.cpu().numpy()
    y_val = labels_t[val_idx_np].cpu().numpy()
    K = int(np.sum(y_val))  # number of positives, used for recall@K

    best_acc = 0.0
    best_state = None

    for epoch in range(epochs):
        stage3.train()
        optimizer.zero_grad()

        logits3_all = stage3(feat_nb)         # [N, 1]
        final_logits_all = base12_logits_fixed + eta2 * logits3_all
        train_logits = final_logits_all[tr_idx]

        # BCEWithLogitsLoss needs floating-point targets; integer node labels
        # would raise. The cast is a no-op when the dtype already matches.
        loss_vec = criterion(train_logits, train_labels.to(train_logits.dtype))
        loss = (weights * loss_vec).mean()
        loss.backward()
        optimizer.step()

        # Validation
        stage3.eval()
        with torch.no_grad():
            logits3_all = stage3(feat_nb)
            final_logits_all = base12_logits_fixed + eta2 * logits3_all
            final_probs_all = torch.sigmoid(final_logits_all)[:, 0].cpu().numpy()

        p_val = final_probs_all[val_idx_np]
        pred_val = (p_val >= 0.5).astype(int)

        try:
            auc = roc_auc_score(y_val, p_val)
        except ValueError:
            # Raised when y_val contains a single class; AUC is undefined.
            auc = 0.0
        auprc = average_precision_score(y_val, p_val)
        acc = accuracy_score(y_val, pred_val)

        # recall@K: fraction of true positives among the K highest-scored nodes.
        sorted_indices = np.argsort(p_val)[::-1]
        top_k_labels = y_val[sorted_indices[:K]] if K > 0 else np.array([])
        rec_at_k = (np.sum(top_k_labels == 1) / K) if K > 0 else 0.0

        per_class_acc = recall_score(y_val, pred_val, average=None, labels=[0, 1])
        acc_neg, acc_pos = per_class_acc[0], per_class_acc[1]
        print(f"[BOOST-3] Epoch {epoch+1}/{epochs} - Loss {loss.item():.4f} | Val AUC {auc:.4f} ACC {acc:.4f} "
              f"AUPRC {auprc:.4f} rec@K {rec_at_k:.4f} | neg_recall {acc_neg:.4f} pos_recall {acc_pos:.4f}")

        if acc > best_acc:
            best_acc = acc
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; deep-copy to freeze a snapshot.
            best_state = copy.deepcopy(stage3.state_dict())

    if best_state is not None:
        stage3.load_state_dict(best_state)
    return stage3
