import os, json, math, argparse, random, time, sys
from typing import List, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import wandb
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score, silhouette_score, confusion_matrix
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import Sampler
import random

# ====================== Utils (无变动) ======================
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + all CUDA devices), then log the seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    print(f"[Seed] {seed}")

def device_of():
    """Return ``torch.device("cuda")`` when CUDA is available, else ``torch.device("cpu")``."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def mean_pooling(model_output, attention_mask):
    """Average the token states in ``model_output[0]`` over positions where ``attention_mask`` is nonzero.

    ``model_output[0]`` is treated as ``[B, L, H]`` (the mask is expanded to its
    size). The token count is clamped to 1e-9 so an all-zero mask yields zeros
    instead of NaN.
    """
    token_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_states.size()).float()
    summed = (token_states * weights).sum(dim=1)
    counts = torch.clamp(weights.sum(dim=1), min=1e-9)
    return summed / counts

class BalancedBatchSampler(Sampler):
    """Batch sampler that draws the same number of indices from every class.

    Each batch holds ``batch_size // num_classes`` indices per class, in shuffled
    order. Per-class index pools are reshuffled with ``np.random`` at the start
    of every epoch. An epoch yields as many batches as the smallest class can
    fill, so leftover samples of larger classes are not visited in that epoch.
    With ``min_per_class >= 2`` every sample has at least one same-class partner
    in its batch.

    Args:
        labels: per-sample class labels (anything ``np.asarray`` accepts).
        batch_size: total batch size; must be positive and divisible by the
            number of distinct classes.
        min_per_class: minimum number of samples per class in each batch.
        drop_last: stored for API compatibility; incomplete batches are always
            dropped regardless of its value.

    Raises:
        ValueError: on empty ``labels``, a non-positive or non-divisible
            ``batch_size``, too few samples per class per batch, or a class too
            small to fill a single batch.
    """

    def __init__(self, labels, batch_size: int, min_per_class: int = 2, drop_last: bool = True):
        super().__init__(None)
        self.labels = np.asarray(labels)
        self.classes = np.unique(self.labels)
        self.num_classes = len(self.classes)
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if self.num_classes == 0:
            raise ValueError("BalancedBatchSampler received no labels.")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if batch_size % self.num_classes != 0:
            raise ValueError(f"batch_size({batch_size}) 必须能被类别数({self.num_classes})整除")
        self.per = batch_size // self.num_classes
        if self.per < min_per_class:
            # A larger batch (or fewer classes) is what raises the per-class count.
            raise ValueError(
                f"每类采样数({self.per}) < 要求下限({min_per_class})，请增大 batch_size 或减少类别数"
            )
        # Index pool of each class (reshuffled every epoch).
        self.idxs = {c: np.where(self.labels == c)[0] for c in self.classes}
        self.drop_last = drop_last
        self._reset_epoch_state()

        # Number of full batches per epoch, limited by the smallest class.
        self.num_batches = min(len(self.idxs[c]) // self.per for c in self.classes)
        if self.num_batches == 0:
            raise ValueError(
                f"任一类别样本数不足以组成一个 batch（每类至少 {self.per}）。"
                f"请减小 batch_size 或放宽策略。"
            )

    def _reset_epoch_state(self):
        """Reshuffle every class pool in place and rewind the per-class read pointers."""
        self.ptr = {}
        for c in self.classes:
            arr = self.idxs[c].copy()
            np.random.shuffle(arr)
            self.idxs[c] = arr
            self.ptr[c] = 0

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        self._reset_epoch_state()
        for _ in range(self.num_batches):
            batch = []
            for c in self.classes:
                s = self.ptr[c]; e = s + self.per
                batch.extend(self.idxs[c][s:e].tolist())
                self.ptr[c] = e
            # Mix classes within the batch (uses Python's ``random`` RNG).
            random.shuffle(batch)
            # Defensive check: an empty batch should be impossible here.
            if not batch:
                raise RuntimeError("BalancedBatchSampler 生成了空 batch（不应发生）。")
            yield batch

# ================== 损失函数 (新) ==================
class SupervisedContrastiveLoss(nn.Module):
    """Supervised contrastive (SupCon) loss on cosine similarities.

    For every anchor ``i`` the positives are the other samples with the same
    label. The softmax denominator runs over all other samples in the batch
    (self-similarity is masked out). The loss is the negative mean, over all
    anchors, of the average log-probability of that anchor's positives.
    Anchors with no positive in the batch contribute 0 but are still counted
    in the mean.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: ``[B, D]`` embeddings. Cosine similarity is used, so they
                need not be pre-normalised.
            labels: class labels, flattened to length ``B``.

        Returns:
            Scalar loss tensor.
        """
        device = features.device
        batch_size = features.shape[0]

        # Move labels to the features' device so CPU labels work with CUDA features.
        labels = labels.contiguous().view(-1, 1).to(device)
        same_class_mask = torch.eq(labels, labels.T)                        # [B, B] bool
        diag_mask = torch.eye(batch_size, device=device, dtype=torch.bool)  # self pairs

        sim_matrix = F.cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim=2)
        sim_matrix = sim_matrix / self.temperature
        # Exclude self-similarity. The most negative *finite* value is used
        # instead of -inf: exp() still underflows to 0 in logsumexp, and the
        # later 0 * log_prob on the diagonal stays 0 rather than NaN.
        sim_matrix = sim_matrix.masked_fill(diag_mask, torch.finfo(sim_matrix.dtype).min)

        log_prob = sim_matrix - torch.logsumexp(sim_matrix, dim=1, keepdim=True)
        pos_mask = same_class_mask & ~diag_mask
        # The clamp avoids 0/0 for anchors without positives (their numerator is 0).
        num_pos = pos_mask.sum(dim=1).clamp(min=1)
        mean_log_prob_pos = (pos_mask.to(log_prob.dtype) * log_prob).sum(dim=1) / num_pos
        return -mean_log_prob_pos.mean()
# ================== Model & Heads ==================
class ContrastiveModel(nn.Module):
    """Pretrained transformer backbone with two embedding outputs.

    ``forward`` returns a dict with:
      * ``"backbone_embedding"``: the attention-masked mean of the backbone's
        first output (via ``mean_pooling``), L2-normalised;
      * ``"projection_embedding"``: the output of
        ``AttnPoolProjector.forward_from_tokens`` on ``last_hidden_state``.

    Args:
        model_name: name or path passed to ``AutoModel.from_pretrained``.
        projection_dim: output size of the projection head.
        freeze: if True, backbone parameters get ``requires_grad=False``;
            otherwise they are explicitly made trainable.
        num_encoder_layers, nhead: accepted for interface compatibility; not
            used by this class.
    """

    def __init__(self, model_name: str, projection_dim: int = 256, freeze: bool = False, num_encoder_layers: int = 2, nhead: int = 8):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        # freeze=True -> frozen backbone; freeze=False -> trainable backbone.
        self.backbone.requires_grad_(not freeze)
        self.projection_head = AttnPoolProjector(self.backbone.config.hidden_size, projection_dim)

    def forward(self, input_ids, attention_mask, **kwargs):
        # Any extra keyword inputs are accepted and ignored (not passed to the backbone).
        backbone_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = mean_pooling(backbone_out, attention_mask)
        return {
            "backbone_embedding": F.normalize(pooled, p=2, dim=1),
            "projection_embedding": self.projection_head.forward_from_tokens(
                backbone_out.last_hidden_state, attention_mask
            ),
        }

# 新增：独立的分类头
class ClassificationHead(nn.Module):
    """Single linear layer mapping ``input_dim`` features to ``num_classes`` logits."""

    def __init__(self, input_dim: int, num_classes: int = 2):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits

class MLPHead(nn.Module):
    """Two-layer MLP classifier: LayerNorm -> Linear(d, 2d) -> GELU -> Dropout -> Linear(2d, num_classes).

    Args:
        input_dim: feature size ``d``.
        num_classes: number of output logits.
        dropout_p: dropout probability after the GELU.
    """

    def __init__(self,
                 input_dim: int,
                 num_classes: int = 2,
                 dropout_p: float = 0.1):
        super().__init__()
        widened = 2 * input_dim
        # Layer order (and therefore state_dict keys net.0 ... net.4) is fixed.
        layers = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, widened),
            nn.GELU(),
            nn.Dropout(dropout_p),
            nn.Linear(widened, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

class ResidualMLPHead(nn.Module):
    """Residual MLP classifier: ``depth`` pre-LayerNorm feed-forward blocks, then a linear layer.

    Each block is ``LayerNorm -> Linear(in_dim, width) -> GELU -> Dropout ->
    Linear(width, in_dim)``. The blocks are stored consecutively in
    ``self.ffn``, 5 modules per block.

    Args:
        in_dim: input feature size (also the width of the residual stream).
        num_classes: number of output logits.
        width: hidden size of each block; defaults to ``2 * in_dim``.
        depth: number of blocks.
        p: dropout probability inside each block.
        per_block_residual: if True (default), every block has its own skip
            connection, ``y = y + block(y)``. If False, a single skip connection
            wraps the whole stack, ``y = ffn(x) + x``. That is how earlier
            versions of this class behaved, and the option is kept for models
            trained that way.

    Parameter names (``ffn.<i>.*``, ``out.*``) are identical in both modes, so
    state dicts are interchangeable.
    """

    _LAYERS_PER_BLOCK = 5

    def __init__(self, in_dim, num_classes: int = 2, width=None, depth=3, p=0.1,
                 per_block_residual: bool = True):
        super().__init__()
        width = width or 2*in_dim
        blocks = []
        cur = in_dim
        for _ in range(depth):
            blocks += [
                nn.LayerNorm(cur),
                nn.Linear(cur, width), nn.GELU(), nn.Dropout(p),
                nn.Linear(width, cur),
            ]
        self.ffn = nn.Sequential(*blocks)
        self.out = nn.Linear(cur, num_classes)
        self.per_block_residual = per_block_residual

    def forward(self, x):
        # Modules pickled before the flag existed keep their legacy (whole-stack) residual.
        if not getattr(self, "per_block_residual", False):
            return self.out(self.ffn(x) + x)
        layers = list(self.ffn)
        y = x
        for start in range(0, len(layers), self._LAYERS_PER_BLOCK):
            h = y
            for layer in layers[start:start + self._LAYERS_PER_BLOCK]:
                h = layer(h)
            y = y + h  # skip connection around this block only
        return self.out(y)

class AttnPoolProjector(nn.Module):
    def __init__(self, hidden, proj_dim, n_query=2):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_query, hidden))
        nn.init.xavier_uniform_(self.query)
        self.out = nn.Linear(hidden*n_query, proj_dim)

    def forward_from_tokens(self, token_h, attn_mask):
        # token_h: [B,L,H]
        q = F.normalize(self.query, dim=-1)                    # [Q,H]
        k = F.normalize(token_h, dim=-1)                       # [B,L,H]
        attn = torch.einsum('qh,blh->bql', q, k) / math.sqrt(token_h.size(-1))
        very_neg = torch.finfo(attn.dtype).min
        attn = attn.masked_fill(~attn_mask[:,None,:].bool(), very_neg)
        w = attn.softmax(dim=-1)                               # [B,Q,L]
        pooled = torch.einsum('bql,blh->bqh', w, token_h).reshape(token_h.size(0), -1)
        z = self.out(pooled)
        return F.normalize(z, p=2, dim=1)

class AsymBCEWithLogits(nn.Module):
    def __init__(self, gamma_pos: float = 0.0, gamma_neg: float = 3.0, fp_cost: float = 1.0):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.fp_cost   = fp_cost 

    def forward(self, logits: torch.Tensor, target: torch.Tensor, ce_weights: torch.Tensor | None = None):
        margin = logits[:, 1] - logits[:, 0]       
        p = torch.sigmoid(margin)                   
        y = target.float()

        pt   = torch.where(y == 1, p, 1 - p)      
        gamma= torch.where(y == 1,
                           torch.full_like(pt, self.gamma_pos),
                           torch.full_like(pt, self.gamma_neg))
        focal = (1 - pt).clamp(min=1e-8).pow(gamma)

        bce = F.binary_cross_entropy_with_logits(margin, y, reduction='none')

        if ce_weights is not None:
            w_cls = torch.where(y == 1, ce_weights[1], ce_weights[0])
        else:
            w_cls = 1.0

        w_fp = torch.where(y == 0, 1.0 + (self.fp_cost - 1.0) * p.detach(), 1.0)

        loss = (w_cls * w_fp * focal * bce).mean()
        return loss

@torch.no_grad()
def collect_probs(model, clf_head, loader, device, amp_dtype):
    model.eval(); clf_head.eval()
    y_true, y_prob = [], []
    for enc, y, _ in loader:
        enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda")):
            out = model(**enc)
            z = out["projection_embedding"]
            logits = clf_head(z)
            prob = F.softmax(logits, dim=1)[:, 1]
        y_true.append(y.cpu()); y_prob.append(prob.cpu())
    y_true = torch.cat(y_true).numpy().astype(int)
    y_prob = torch.cat(y_prob).numpy()
    return y_true, y_prob

def find_threshold_for_precision(y_true: np.ndarray, y_score: np.ndarray,
                                 prec_target: float = 0.95) -> tuple[float, dict]:
    """Choose a score threshold (samples with ``score >= thr`` are predicted positive).

    Among thresholds whose precision is at least ``prec_target``, the one with
    the highest recall is returned. If none qualifies, the threshold with the
    highest precision is used (the highest-scoring such one).

    Only thresholds that can actually be realised with ``>=`` are considered:
    all samples tied at the threshold score are selected together. The returned
    stats therefore match what ``y_score >= thr`` produces.

    Args:
        y_true: 1-D labels; 1 counts as positive, 0 as negative.
        y_score: 1-D scores, higher means more positive.
        prec_target: required precision.

    Returns:
        ``(thr, stats)`` with ``stats`` keys ``thr``, ``precision``, ``recall``
        and ``num_selected``. Empty input gives ``thr = inf`` and zero stats.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    n = y_score.shape[0]
    if n == 0:
        inf = float("inf")
        return inf, {"thr": inf, "precision": 0.0, "recall": 0.0, "num_selected": 0}

    order = np.argsort(-y_score, kind="stable")
    y_sorted = y_true[order]
    s_sorted = y_score[order]

    ctp = np.cumsum(y_sorted == 1)
    cfp = np.cumsum(y_sorted == 0)

    # A cut is reachable only at the LAST position of each run of tied scores.
    is_cut = np.ones(n, dtype=bool)
    is_cut[:-1] = s_sorted[:-1] != s_sorted[1:]
    cuts = np.flatnonzero(is_cut)

    tp_at = ctp[cuts]
    precision = tp_at / np.maximum(1, tp_at + cfp[cuts])
    recall = tp_at / max(1, int((y_true == 1).sum()))

    ok = np.flatnonzero(precision >= prec_target)
    if ok.size == 0:
        j = int(np.argmax(precision))
    else:
        j = int(ok[np.argmax(recall[ok])])

    i = int(cuts[j])
    thr = float(s_sorted[i])
    stats = {
        "thr": thr,
        "precision": float(precision[j]),
        "recall":    float(recall[j]),
        "num_selected": i + 1,
    }
    return thr, stats

@torch.no_grad()
def collect_scores(model, clf_head, loader, device, amp_dtype):
    model.eval(); clf_head.eval()
    ys, scores = [], []
    for enc, y, _ in loader:
        enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda")):
            z = model(**enc)["projection_embedding"]
            logits = clf_head(z)
            margin = (logits[:, 1] - logits[:, 0]).to(torch.float32)
        ys.append(y.cpu())
        scores.append(margin.cpu())
    return torch.cat(ys).numpy(), torch.cat(scores).numpy()

def collect_all_zero_prompt_ids(jsonl_path: str) -> set[int]:
    """Return the prompt ids (>= 0) that appear in ``jsonl_path`` but never with ``label == 1``.

    Blank lines are skipped. Records without ``prompt_id`` (or with a negative
    one) are ignored. A missing ``label`` counts as 0.
    """
    seen: set[int] = set()
    with_positive: set[int] = set()
    with open(jsonl_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            if not raw.strip():
                continue
            record = json.loads(raw)
            prompt_id = int(record.get("prompt_id", -1))
            if prompt_id < 0:
                continue
            label = int(record.get("label", 0))
            seen.add(prompt_id)
            if label == 1:
                with_positive.add(prompt_id)
    return seen - with_positive


def mark_ambiguous_pos_by_neighbors(
    embeddings: np.ndarray,  # [N,D], L2-normalized
    labels: np.ndarray,      # [N], 0/1
    k: int = 5,
    min_sim: float = 0.0,
) -> np.ndarray:
    """Flag positive samples whose nearest neighbours are mostly negatives.

    Similarity is ``embeddings @ embeddings.T`` (cosine similarity when rows are
    L2-normalised). For each sample the ``k`` most similar *other* samples are
    taken, and neighbours with similarity below ``min_sim`` are ignored. A
    sample with label 1 is flagged when at least ``(k + 1) // 2`` of the
    remaining neighbours have label 0 (a tie counts for even ``k``).

    ``k`` is clamped to ``N - 1`` so inputs with at most ``k`` rows do not crash
    ``np.argpartition``. The vote threshold is still derived from the requested ``k``.

    Returns:
        Boolean array of shape ``[N]``; True marks an ambiguous positive.
    """
    N = embeddings.shape[0]
    y = labels.astype(np.int64)
    k_eff = min(int(k), N - 1)
    if k_eff <= 0:
        return np.zeros(N, dtype=bool)

    S = embeddings @ embeddings.T
    np.fill_diagonal(S, -np.inf)  # a sample is never its own neighbour
    idx_top = np.argpartition(S, -k_eff, axis=1)[:, -k_eff:]   # [N, k_eff], unordered
    sims_top = np.take_along_axis(S, idx_top, axis=1)

    valid = sims_top >= min_sim
    neg_votes = ((y[idx_top] == 0) & valid).sum(axis=1)
    need = (int(k) + 1) // 2
    return (y == 1) & (neg_votes >= need)

# Exponential moving average (EMA) of trainable model parameters
class ModelEMA:
    """Exponential moving average (EMA) of a module's trainable parameters.

    ``shadow`` holds the averaged copies. ``apply_shadow`` copies them into the
    live module (e.g. for evaluation) and ``restore`` puts the live weights
    back. Only parameters with ``requires_grad=True`` at construction time are
    tracked; buffers are not averaged.

    Args:
        module: module whose parameters are averaged.
        decay: EMA decay; each update does ``shadow = decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, module: nn.Module, decay: float = 0.999):
        self.module = module
        self.decay = decay
        self.shadow = [p.detach().clone() for p in module.parameters() if p.requires_grad]
        self.params = [p for p in module.parameters() if p.requires_grad]

    @torch.no_grad()
    def update(self):
        """Blend the current parameter values into ``shadow`` in place."""
        for s, p in zip(self.shadow, self.params):
            s.mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    @torch.no_grad()
    def apply_shadow(self):
        """Back up the live parameters and overwrite them with the EMA values.

        Raises:
            RuntimeError: if called again before ``restore``. A second backup
                would capture the EMA weights and lose the live ones.
        """
        if getattr(self, "backup", None) is not None:
            raise RuntimeError("ModelEMA.apply_shadow() called twice without restore(); "
                               "the live weights would be lost.")
        self.backup = [p.detach().clone() for p in self.params]
        for p, s in zip(self.params, self.shadow):
            p.copy_(s)

    @torch.no_grad()
    def restore(self):
        """Copy the weights saved by ``apply_shadow`` back into the module.

        Raises:
            RuntimeError: if ``apply_shadow`` has not been called first.
        """
        backup = getattr(self, "backup", None)
        if backup is None:
            raise RuntimeError("ModelEMA.restore() called without a preceding apply_shadow().")
        for p, b in zip(self.params, backup):
            p.copy_(b)
        del self.backup

# 查看FP
@torch.no_grad()
def collect_scores_and_pids(model, clf_head, loader, device, amp_dtype):
    model.eval(); clf_head.eval()
    ys, scores, pids_all = [], [], []
    for enc, y, pids in loader:               
        enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda")):
            z = model(**enc)["projection_embedding"]
            logits = clf_head(z)

            margin = (logits[:, 1] - logits[:, 0]).to(torch.float32)
        ys.append(y.cpu())
        scores.append(margin.cpu())
        pids_all.append(pids.cpu())
    return (torch.cat(ys).numpy(),
            torch.cat(scores).numpy(),
            torch.cat(pids_all).numpy())

def analyze_fp_difficulty(y_true: np.ndarray,
                          y_score: np.ndarray,
                          prompt_ids: np.ndarray,
                          all_zero_prompts: set[int],
                          out_csv: str | None = None) -> dict:
    y_pred = (y_score >= 0.0).astype(int)

    fp_mask = (y_pred == 1) & (y_true == 0)
    n_fp = int(fp_mask.sum())
    if n_fp == 0:
        stats = {"fp_total": 0, "fp_allzero": 0, "fp_allzero_ratio": 0.0}
        return stats

    fp_pids = prompt_ids[fp_mask]
    is_allzero = np.array([int(pid in all_zero_prompts) for pid in fp_pids], dtype=np.int32)
    n_allzero = int(is_allzero.sum())
    ratio = float(n_allzero / max(1, n_fp))

    stats = {"fp_total": n_fp, "fp_allzero": n_allzero, "fp_allzero_ratio": ratio}

    if out_csv:
        import csv
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["idx_in_eval", "prompt_id", "score_margin", "is_allzero"])
            idxs = np.where(fp_mask)[0]
            for i, pid, sc, az in zip(idxs, fp_pids, y_score[fp_mask], is_allzero):
                w.writerow([int(i), int(pid), float(sc), int(az)])

    print(f"[FP] total={n_fp}, all-zero={n_allzero} ({ratio:.3%})")
    return stats

class JsonlRows:
    def __init__(self, path: str, model_id: int | None, skip: set | None):
        self.rows: List[Dict[str,Any]] = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                ex = json.loads(line)
                if (model_id is not None) and (int(ex.get("model_id", -1)) != model_id):
                    continue
                if (skip is not None) and ex.get("category", -1) in skip:
                    continue
                self.rows.append({"text": ex["prompt"], "label": int(ex["label"]), "prompt_id": int(ex["prompt_id"]),})
        if not self.rows:
            print(f"[WARN] No rows loaded from {path} with model_id={model_id}")
            return
        pos = sum(r["label"] for r in self.rows); neg = len(self.rows) - pos
        print(f"[Load] {path} -> {len(self.rows)} samples (pos={pos}, neg={neg}, pos_ratio={pos/len(self.rows):.3f})")


class SimpleDataset(Dataset):
    """Map-style dataset over row dicts; each item is a ``(text, label, prompt_id)`` tuple."""

    def __init__(self, rows: List[Dict[str, Any]]):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        return row["text"], row["label"], row["prompt_id"]


def collate_classification(batch, tokenizer, max_len=512):
    """Collate ``(text, label, prompt_id)`` items into ``(encoding, labels, prompt_ids)``.

    Texts are tokenized together with truncation to ``max_len`` and padding to
    the longest item. Labels and prompt ids become ``torch.long`` tensors.
    """
    texts, labels, pids = zip(*batch)
    encoded = tokenizer(
        list(texts),
        truncation=True,
        max_length=max_len,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    label_tensor = torch.tensor(labels, dtype=torch.long)
    pid_tensor = torch.tensor(pids, dtype=torch.long)
    return encoded, label_tensor, pid_tensor



def cosine_sim_np(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise similarity matrix ``a @ b.T``.

    This equals cosine similarity only when the rows of ``a`` and ``b`` are
    already L2-normalised; no normalisation happens here.
    """
    b_transposed = b.T
    return a @ b_transposed


@torch.no_grad()
def encode_for_eval(model, loader, device, amp_dtype):
    model.eval()
    all_z, all_y = [], []
    for enc, y, _ in loader:
        enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda")):
            outputs = model(**enc)
            all_z.append(outputs["projection_embedding"].cpu())
            all_y.append(y.cpu())
    return torch.cat(all_z).numpy(), torch.cat(all_y).numpy()

def evaluate_separability(model, loader, device, amp_dtype):
    """Cosine silhouette score of the projection embeddings grouped by label.

    Returns ``{"silhouette_cosine": score}``. The score is -1.0 when there are
    fewer than two classes or samples, or when ``silhouette_score`` raises
    ``ValueError``.
    """
    embeddings, labels = encode_for_eval(model, loader, device, amp_dtype)
    degenerate = np.unique(labels).size < 2 or len(embeddings) < 2
    if degenerate:
        return {"silhouette_cosine": -1.0}

    try:
        score = silhouette_score(embeddings, labels, metric="cosine")
    except ValueError:
        score = -1.0
    return {"silhouette_cosine": float(score)}

@torch.no_grad()
def evaluate_classifier(model, clf_head, loader, device, amp_dtype):
    """Evaluate ``clf_head`` on projection embeddings with binary metrics.

    The score is the float32 margin ``logits[:, 1] - logits[:, 0]``, and the
    prediction is ``score >= 0``. AUC is computed from the raw margin rather
    than a sigmoid of it: sigmoid is monotone, so the AUC would be the same,
    but float32 sigmoid saturates to exactly 1.0/0.0 for large margins and
    those artificial ties distort the AUC. AUC is NaN when ``y_true`` has a
    single class.

    Returns:
        dict with ``acc``, ``f1``, ``precision``, ``recall``, ``auc`` and the
        confusion counts ``TP``, ``FP``, ``TN``, ``FN``.
    """
    model.eval()
    clf_head.eval()
    all_labels, all_scores = [], []

    for enc, y, _ in loader:
        enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda")):
            outputs = model(**enc)
            z = outputs["projection_embedding"]
            logits = clf_head(z)
            s = (logits[:, 1] - logits[:, 0]).to(torch.float32)
        all_labels.append(y.cpu())
        all_scores.append(s.cpu())

    y_true = torch.cat(all_labels).numpy()
    y_score = torch.cat(all_scores).numpy()
    y_pred = (y_score >= 0.0).astype(int)

    # Fixed label order keeps the matrix 2x2 even if only one class is present.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    # roc_auc_score raises on a single-class y_true.
    auc = float("nan") if np.unique(y_true).size < 2 else roc_auc_score(y_true, y_score)

    return {"acc": acc, "f1": f1, "precision": prec, "recall": rec, "auc": auc,
            "TP": int(tp), "FP": int(fp), "TN": int(tn), "FN": int(fn)}

@torch.no_grad()
def batch_cosine_stats(z: torch.Tensor, y: torch.Tensor):
    """Summarise within-batch similarities of positive vs negative pairs.

    The similarity is ``z @ z.T``. That is cosine similarity only if the rows
    of ``z`` are L2-normalised, which is presumably what callers pass (not
    checked here). Self-pairs are excluded. Positive pairs share a label;
    negative pairs do not.

    Returns:
        dict with ``pos_mean``/``pos_median``/``neg_mean``/``neg_median`` (NaN
        when there is no such pair), ``pos_cnt/mean`` and ``pos_cnt/min``
        (positives per anchor), and ``pos_cover`` (fraction of anchors with at
        least one positive).
    """
    n = z.size(0)
    sim = z @ z.t()
    not_self = ~torch.eye(n, device=z.device, dtype=torch.bool)
    same_label = y.view(-1, 1) == y.view(1, -1)
    pos_mask = same_label & not_self
    neg_mask = ~same_label & not_self

    def _mean_median(mask):
        # (mean, median) of the selected similarities; NaN pair if none selected.
        if not mask.any():
            return float("nan"), float("nan")
        vals = sim[mask]
        return vals.mean().item(), vals.median().item()

    pos_mean, pos_median = _mean_median(pos_mask)
    neg_mean, neg_median = _mean_median(neg_mask)
    pos_cnt = pos_mask.sum(dim=1).float()
    return {
        "pos_mean": pos_mean,
        "pos_median": pos_median,
        "neg_mean": neg_mean,
        "neg_median": neg_median,
        "pos_cnt/mean": pos_cnt.mean().item(),
        "pos_cnt/min": pos_cnt.min().item(),
        "pos_cover": pos_mask.any(dim=1).float().mean().item(),
    }

def log_batch_stats_wandb(stats: dict, prefix="sim/"):
    """Log ``stats`` to the active W&B run with ``prefix`` prepended to every key.

    Does nothing when no W&B run is active.
    """
    if wandb.run is not None:
        wandb.log({prefix + name: value for name, value in stats.items()})


def read_jsonl(path: str) -> list[dict]:
    """Parse a UTF-8 JSONL file into a list of objects, skipping blank or whitespace-only lines."""
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [json.loads(text) for text in stripped if text]

def write_jsonl(path: str, exs: list[dict]):
    """Write ``exs`` to ``path`` as UTF-8 JSONL, one object per line, keeping non-ASCII characters unescaped."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in exs)

@torch.no_grad()
def encode_texts_projection(model, tokenizer, examples: list[dict], device, batch_size=128, max_len=512, amp_dtype=None,
                            text_key="prompt"):
    model.eval()
    Z = []
    for i in range(0, len(examples), batch_size):
        texts = [ex[text_key] for ex in examples[i:i+batch_size]]
        enc = tokenizer(texts, truncation=True, max_length=max_len, padding=True,
                        return_tensors="pt", return_attention_mask=True)
        enc = {k: v.to(device, non_blocking=True) for k,v in enc.items()}
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda" and amp_dtype is not None)):
            out = model(**enc)
            z = out["projection_embedding"]          # 已经是 L2 归一化
        Z.append(z.cpu())
    return torch.cat(Z, dim=0).numpy()               # [N, D], L2-normalized

def neighbor_flip_only_pos_to_neg(
    embeddings: np.ndarray,            # [N, D], L2-normalized
    labels: np.ndarray,                # [N], in {0,1}
    k: int = 5,
    min_sim: float = 0.0,
) -> tuple[np.ndarray, dict]:
    """Relabel positives (1 -> 0) whose nearest neighbours are mostly negatives.

    Similarity is ``embeddings @ embeddings.T`` (cosine similarity for
    L2-normalised rows). For each sample the ``k`` most similar *other* samples
    are taken, and those with similarity below ``min_sim`` are discarded. A
    positive is flipped to 0 when at least ``k // 2 + 1`` of the remaining
    neighbours are negatives, i.e. a strict majority of the requested ``k``.
    For odd ``k`` this equals ``(k + 1) // 2``, the threshold used by
    ``mark_ambiguous_pos_by_neighbors``. Votes always use the ORIGINAL labels,
    so flips do not cascade. Labels other than 1 are never changed.

    ``k`` is clamped to ``N - 1`` so small inputs do not crash
    ``np.argpartition``; the vote threshold still comes from the requested ``k``.

    Returns:
        ``(new_labels, stats)``. ``new_labels`` is an int64 copy. ``stats`` has
        keys ``N``, ``considered`` (samples with at least one neighbour at or
        above ``min_sim``), ``flips_1_to_0``, ``pos_before`` and ``pos_after``.
    """
    N = embeddings.shape[0]
    y = labels.astype(np.int64)
    new_y = y.copy()
    considered = 0
    flips = 0

    k_eff = min(int(k), N - 1)
    if k_eff > 0:
        S = embeddings @ embeddings.T
        np.fill_diagonal(S, -np.inf)  # a sample is never its own neighbour
        idx_top = np.argpartition(S, -k_eff, axis=1)[:, -k_eff:]   # [N, k_eff], unordered
        sims_top = np.take_along_axis(S, idx_top, axis=1)

        valid = sims_top >= min_sim
        has_valid = valid.any(axis=1)
        neg_votes = ((y[idx_top] == 0) & valid).sum(axis=1)
        need = int(k) // 2 + 1  # strict majority of k (only 1 -> 0 is allowed)
        flip = (y == 1) & has_valid & (neg_votes >= need)

        new_y[flip] = 0
        considered = int(has_valid.sum())
        flips = int(flip.sum())

    stats = {
        "N": int(N),
        "considered": int(considered),
        "flips_1_to_0": int(flips),
        "pos_before": int((y==1).sum()),
        "pos_after":  int((new_y==1).sum()),
    }
    return new_y, stats



# ================= Main (重大修改) =================
def _str2bool(x) -> bool:
    """Parse a CLI boolean: only the (case-insensitive) string "true" is True."""
    return str(x).lower() == "true"


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the two-stage training script.

    Several flags (``--lr``, ``--lambda_triplet``, ``--lambda_nce``, ``--margin``,
    ``--topk_build``, ``--epochs_contrastive``, ``--supcon_temp``) belong to the
    contrastive Stage 1, which ``main`` currently does not run. They are kept so
    existing launch scripts keep working, and they are still recorded in the
    wandb config and in the checkpoints.

    NOTE(review): ``--early_stop_patience`` and ``--hardflip_label_key`` are
    parsed but never read by ``main``.
    """
    ap = argparse.ArgumentParser("Two-stage training: Contrastive Projection + Classification Head")

    # Encoder / data / shared flags.
    ap.add_argument("--model_name", type=str, default="sentence-transformers/all-mpnet-base-v2")
    ap.add_argument("--train_path", type=str, required=True)
    ap.add_argument("--val_path",   type=str, required=True)
    ap.add_argument("--model_id",   type=int, default=0)
    ap.add_argument("--max_len",    type=int, default=512)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr",         type=float, default=2e-5, help="LR for projection head")
    ap.add_argument("--projection_dim", type=int, default=384)
    ap.add_argument("--lambda_triplet", type=float, default=1.0)
    ap.add_argument("--lambda_nce",     type=float, default=1.0)
    ap.add_argument("--margin",     type=float, default=0.3)
    ap.add_argument("--topk_build", type=int, default=10, help="k for hard triplet mining")
    ap.add_argument("--bf16",       type=_str2bool, default=True)
    ap.add_argument("--seed",       type=int, default=42)
    ap.add_argument("--output_dir", type=str, default="./two_stage_ckpts")

    # Stage lengths and classifier learning rate.
    ap.add_argument("--epochs_contrastive", type=int, default=15, help="Epochs for Stage 1")
    ap.add_argument("--epochs_classifier",  type=int, default=10, help="Epochs for Stage 2")
    ap.add_argument("--lr_classifier",      type=float, default=1e-3, help="LR for classification head")

    ap.add_argument("--supcon_temp", type=float, default=0.1, help="Temperature for SupCon loss")
    ap.add_argument("--freeze", action="store_true")

    # Optional neighbour-vote 1->0 relabelling of the training set.
    ap.add_argument("--pre_flip_enable", type=_str2bool, default=False,
                    help="是否启用基于近邻一致性的 1->0 标签翻转预处理")
    ap.add_argument("--pre_flip_k", type=int, default=5,
                    help="Top-K 近邻（3 或 5）")
    ap.add_argument("--pre_flip_min_sim", type=float, default=0.0,
                    help="近邻最小相似度阈值（低于此相似度的近邻不计入表决）")
    ap.add_argument("--pre_flip_out", type=str, default=None,
                    help="翻转后训练集的输出 jsonl 路径（默认 train_path + .flip.jsonl）")
    ap.add_argument("--pre_flip_proj_ckpt", type=str, default=None,
                    help="用于编码的投影头权重（如 best_projection.pt），必须已训练完成")

    # Loss choice and precision-constrained model selection.
    ap.add_argument("--loss_type", choices=["ce", "asym_bce"], default="asym_bce")
    ap.add_argument("--gamma_pos", type=float, default=0.0)
    ap.add_argument("--gamma_neg", type=float, default=3.0)   # suggested range: 2-5
    ap.add_argument("--fp_cost",   type=float, default=1.0)   # >1 penalises false positives more
    ap.add_argument("--prec_target", type=float, default=0.8,
                    help="验证集上阈值选择时的目标精确率（precision ≥ 该值）")
    ap.add_argument("--save_best_by", choices=["recall","f1"], default="f1")

    # Optional label smoothing of the training set.
    ap.add_argument("--smooth_enable", type=_str2bool, default=False,
                    help="启用基于全模型一致性 + 邻居一致性的 label smoothing")
    ap.add_argument("--smooth_all_data_path", type=str, default=None,
                    help="用于统计‘所有模型均为0’的jsonl（需包含 prompt_id、model_id、label）。不填则尝试用 --train_path")
    ap.add_argument("--smooth_proj_ckpt", type=str, default=None,
                    help="编码文本的 best_projection.pt（必须已训练完成）")
    ap.add_argument("--smooth_k", type=int, default=5, choices=[3,5], help="KNN的k")
    ap.add_argument("--smooth_min_sim", type=float, default=0.0, help="KNN最小相似度阈值")
    ap.add_argument("--smooth_pos_low", type=float, default=0.3, help="疑难正样本的下界(含)")
    ap.add_argument("--smooth_pos_high", type=float, default=0.5, help="疑难正样本的上界(含)")
    ap.add_argument("--smooth_out", type=str, default=None, help="平滑后训练集输出路径(默认 train_path + .smooth.jsonl)")

    # LR scheduler and optimisation extras.
    ap.add_argument("--scheduler", type=str, default="cosine",
                    choices=["none","onecycle","cosine","plateau","linear_warmup"])
    ap.add_argument("--warmup_ratio", type=float, default=0.1)
    ap.add_argument("--min_lr", type=float, default=1e-6)
    ap.add_argument("--max_grad_norm", type=float, default=1.0)
    ap.add_argument("--early_stop_patience", type=int, default=5)
    ap.add_argument("--ema_decay", type=float, default=0.0, help="0=off, e.g. 0.999")

    # Optional hard-sample relabelling (0->1 for all-zero prompts).
    # NOTE(review): the help text says "all-zero prompts become 1, everything else 0",
    # but _apply_hard_flip only flips 0->1 and keeps existing 1 labels. Confirm which
    # behaviour is intended.
    ap.add_argument("--hardflip_enable", type=_str2bool, default=False,
                    help="启用困难样本探测器训练：将 all_zero prompts 标为1，其余为0")
    ap.add_argument("--hardflip_all_data_path", type=str, default=None,
                    help="统计 all_zero 的多模型全量数据 jsonl（需含 prompt_id, model_id, label）")
    ap.add_argument("--hardflip_out", type=str, default=None,
                    help="输出 hard/not-hard 训练集（默认 train_path + .hard.jsonl）")
    ap.add_argument("--hardflip_label_key", type=str, default="label_hard",
                    help="新标签字段名，默认 label_hard")
    ap.add_argument("--fp_allzero_ref_path", type=str, default=None,
                    help="评估FP时统计all-zero所用的多模型jsonl（应为验证/评估集，需含 prompt_id, model_id, label）。缺省则使用 --val_path。")
    return ap


def _apply_pre_flip(args, model, tok, device) -> str:
    """Relabel training positives to 0 by a nearest-neighbour vote in projection space.

    Loads ``--pre_flip_proj_ckpt`` into ``model``. It then encodes the ``prompt`` of
    every example in ``args.train_path``; no ``model_id`` filter is applied here.
    The labels are passed through ``neighbor_flip_only_pos_to_neg`` and the
    relabelled set is written out.

    Returns:
        Path of the written jsonl (``--pre_flip_out`` or ``<train_path>.flip.jsonl``).

    Raises:
        ValueError: the projection checkpoint is not given or does not exist.
        RuntimeError: the training file is empty, or the checkpoint cannot be
            loaded and does not contain any ``projection_head`` keys.
    """
    ckpt = args.pre_flip_proj_ckpt
    if not ckpt or not os.path.exists(ckpt):
        raise ValueError("--pre_flip_enable 已开启，但未提供或找不到 --pre_flip_proj_ckpt（请给 best_projection.pt）")

    print("[PreFlip] loading projection head:", ckpt)
    state = torch.load(ckpt, map_location="cpu")
    try:
        model.load_state_dict(state, strict=True)
    except RuntimeError:
        # Accept a partial checkpoint only if it contains projection-head weights.
        if isinstance(state, dict) and "projection_head" in "".join(state.keys()):
            missing, unexpected = model.load_state_dict(state, strict=False)
            print("[PreFlip] loaded with relaxed strict=False;", "missing:", len(missing), "unexpected:", len(unexpected))
        else:
            raise

    exs_train = read_jsonl(args.train_path)
    if not exs_train:
        raise RuntimeError(f"[PreFlip] 空训练集：{args.train_path}")

    Z = encode_texts_projection(model, tok, exs_train, device,
                                batch_size=128, max_len=args.max_len,
                                amp_dtype=(torch.bfloat16 if args.bf16 else None),
                                text_key="prompt")

    labels = np.array([int(ex["label"]) for ex in exs_train], dtype=np.int64)
    new_y, st = neighbor_flip_only_pos_to_neg(
        embeddings=Z, labels=labels, k=args.pre_flip_k, min_sim=args.pre_flip_min_sim
    )
    print(f"[PreFlip] stats: {st}")

    out_p = args.pre_flip_out or (args.train_path + ".flip.jsonl")
    for ex, ny in zip(exs_train, new_y.tolist()):
        ex["label"] = int(ny)
    write_jsonl(out_p, exs_train)
    print(f"[PreFlip] saved flipped training set -> {out_p}")
    return out_p


def _apply_hard_flip(args) -> str:
    """Flip label 0 -> 1 for training examples whose prompt_id is an "all-zero" prompt.

    The all-zero prompt set comes from ``collect_all_zero_prompt_ids`` on
    ``--hardflip_all_data_path`` (or the current ``args.train_path``). Per the CLI
    help, these are presumably prompts that every model labelled 0. Examples
    without a ``prompt_id`` are left untouched. Examples already labelled 1 keep
    their label.

    Returns:
        Path of the written jsonl (``--hardflip_out`` or ``<train_path>.hard.jsonl``).

    Raises:
        RuntimeError: the training file is empty.
    """
    src_all = args.hardflip_all_data_path or args.train_path
    all_zero_prompts = collect_all_zero_prompt_ids(src_all)
    if not all_zero_prompts:
        print("[HardFlip] Warning: 未统计到任何 all-zero prompt（请确认多模型对齐数据）。")

    exs_train = read_jsonl(args.train_path)
    if not exs_train:
        raise RuntimeError(f"[HardFlip] 空训练集：{args.train_path}")

    flip_0_to_1 = keep_1 = total = 0
    for ex in exs_train:
        pid = int(ex.get("prompt_id", -1))
        if pid < 0:
            # No prompt_id: keep the example exactly as it is.
            continue

        y = int(ex.get("label", 0))
        if pid in all_zero_prompts and y == 0:
            ex["label"] = 1     # only 0 -> 1 flips are performed
            flip_0_to_1 += 1
        else:
            # Everything else keeps its original label, normalised to int.
            ex["label"] = y
            if y == 1:
                keep_1 += 1
        total += 1

    print(f"[HardFlip] effected: flip_0_to_1={flip_0_to_1}, keep_original_1={keep_1}, total_seen={total}")

    out_p = args.hardflip_out or (args.train_path + ".hard.jsonl")
    write_jsonl(out_p, exs_train)

    # Diagnostic only: label distribution of the rows that will be used for training
    # (i.e. after JsonlRows filters on model_id).
    tmp_rows = JsonlRows(out_p, args.model_id, None).rows
    if tmp_rows:
        pos = sum(r["label"] for r in tmp_rows)
        neg = len(tmp_rows) - pos
        print(f"[HardFlip] after relabel (model_id={args.model_id}): pos={pos}, neg={neg}, pos_ratio={pos/len(tmp_rows):.3f}")
    return out_p


def _apply_smoothing(args, model, tok, device) -> str:
    """Attach a soft target ``label_smooth`` to the current model's training examples.

    The target is chosen as follows:

    * ``0.0`` if the example's prompt_id is in the all-zero prompt set;
    * a uniform random value in ``[--smooth_pos_low, --smooth_pos_high]`` for positives
      that ``mark_ambiguous_pos_by_neighbors`` flags on the projection embeddings;
    * the hard label otherwise.

    Rows that belong to other ``model_id`` values are copied through unchanged.
    They are written first, followed by the current model's rows in their
    original order.

    Returns:
        Path of the written jsonl (``--smooth_out`` or ``<train_path>.smooth.jsonl``).

    Raises:
        ValueError: ``--smooth_proj_ckpt`` is not given or does not exist.
        RuntimeError: there are no training rows for ``--model_id``.
    """
    if not args.smooth_proj_ckpt or not os.path.exists(args.smooth_proj_ckpt):
        raise ValueError("--smooth_enable 启用但缺少 --smooth_proj_ckpt (best_projection.pt)")

    src_all = args.smooth_all_data_path or args.train_path
    all_zero_prompts = collect_all_zero_prompt_ids(src_all)
    if not all_zero_prompts:
        print("[Smooth] Warning: 未统计到任何 all-zero prompt（可能缺少 prompt_id 或数据仅包含单模型）。")

    exs_train = read_jsonl(args.train_path)
    # Rows without model_id are treated as belonging to the current model.
    exs_small, exs_other = [], []
    for ex in exs_train:
        if int(ex.get("model_id", args.model_id)) == args.model_id:
            exs_small.append(ex)
        else:
            exs_other.append(ex)
    if not exs_small:
        raise RuntimeError(f"[Smooth] 当前模型(model_id={args.model_id})的训练样本为空。")

    print("[Smooth] loading projection head:", args.smooth_proj_ckpt)
    state = torch.load(args.smooth_proj_ckpt, map_location="cpu")
    try:
        model.load_state_dict(state, strict=True)
    except RuntimeError:
        missing, unexpected = model.load_state_dict(state, strict=False)
        print("[Smooth] relaxed load; missing:", len(missing), "unexpected:", len(unexpected))

    Z = encode_texts_projection(model, tok, exs_small, device,
                                batch_size=128, max_len=args.max_len,
                                amp_dtype=(torch.bfloat16 if args.bf16 else None),
                                text_key="prompt")
    y_small = np.array([int(ex.get("label", 0)) for ex in exs_small], dtype=np.int64)

    amb_mask = mark_ambiguous_pos_by_neighbors(
        embeddings=Z, labels=y_small,
        k=args.smooth_k, min_sim=args.smooth_min_sim
    )

    low, high = float(args.smooth_pos_low), float(args.smooth_pos_high)
    rng = np.random.default_rng(42)   # fixed seed: reproducible soft targets

    out_p = args.smooth_out or (args.train_path + ".smooth.jsonl")
    with open(out_p, "w", encoding="utf-8") as f:
        for ex in exs_other:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

        for i, ex in enumerate(exs_small):
            lab = int(ex.get("label", 0))
            pid = int(ex.get("prompt_id", -1))
            if pid >= 0 and pid in all_zero_prompts:
                ex["label_smooth"] = 0.0
            elif lab == 1 and amb_mask[i]:
                ex["label_smooth"] = float(rng.uniform(low, high))
            else:
                ex["label_smooth"] = float(lab)
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"[Smooth] saved smoothed training set -> {out_p}")
    return out_p


def _build_scheduler(args, optimizer, steps_per_epoch: int):
    """Create the LR scheduler selected by ``--scheduler``; return ``None`` for ``"none"``.

    ``onecycle``, ``cosine`` and ``linear_warmup`` are stepped once per optimiser
    step. ``plateau`` is stepped once per epoch with the model-selection score,
    where higher is better.
    """
    from transformers import get_linear_schedule_with_warmup

    total_steps = args.epochs_classifier * steps_per_epoch
    if args.scheduler == "onecycle":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=args.lr_classifier,
            total_steps=total_steps, pct_start=args.warmup_ratio,
            anneal_strategy="cos", final_div_factor=1e4
        )
    if args.scheduler == "cosine":
        # The first cycle lasts one epoch; each later cycle doubles in length.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=steps_per_epoch, T_mult=2, eta_min=args.min_lr
        )
    if args.scheduler == "linear_warmup":
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=int(args.warmup_ratio * total_steps), num_training_steps=total_steps
        )
    if args.scheduler == "plateau":
        # `verbose` is deprecated for torch LR schedulers. The LR is logged to wandb every step instead.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=2, min_lr=args.min_lr
        )
    return None


def main():
    """Train a classification head on frozen projection embeddings (Stage 2).

    Pipeline:

    1. Seed, device, output dir and wandb setup. Load the train/val rows for ``--model_id``.
    2. Optionally relabel the training set. The steps run in the order below, and each
       one writes a new jsonl that becomes ``args.train_path`` for the next step:
       pre-flip (1 -> 0 by neighbour vote), hard-flip (0 -> 1 for all-zero prompts),
       smoothing (adds ``label_smooth``).
    3. Load ``<output_dir>/best_projection.pt`` into the encoder and freeze it.
       Stage 1 (contrastive training of the projection head) is not run by this
       script, so that checkpoint must come from an earlier run.
    4. Train ``MLPHead`` on the frozen embeddings. After every epoch it evaluates on
       the validation set and writes a per-epoch checkpoint plus best-by-F1 and
       best-by-precision-constrained-score checkpoints.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed)
    device = device_of()
    amp_dtype = torch.bfloat16 if args.bf16 else torch.float16
    os.makedirs(args.output_dir, exist_ok=True)
    wandb.init(project="emb-contrastive-clf", name=f"two-stage-{os.path.basename(args.model_name)}", config=vars(args))

    # --- Data loading ---
    tok = AutoTokenizer.from_pretrained(args.model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    tr_rows = JsonlRows(args.train_path, args.model_id, None).rows
    va_rows = JsonlRows(args.val_path,   args.model_id, None).rows
    if not tr_rows or not va_rows:
        sys.exit("Error: No training or validation data loaded.")

    # --- Model initialisation ---
    model = ContrastiveModel(
        model_name=args.model_name,
        projection_dim=args.projection_dim,
        freeze=args.freeze,
    ).to(device)

    val_dl_simple = DataLoader(SimpleDataset(va_rows), batch_size=args.batch_size * 2, shuffle=False, num_workers=2,
                               collate_fn=lambda b: collate_classification(b, tok, args.max_len))

    # =================================================================
    # == STAGE 2: Train Classification Head                          ==
    # =================================================================
    print("\n" + "="*60)
    print("== STAGE 2: Training Classification Head ==")
    print("="*60)

    # Optional relabelling. Each step reads the previous step's output.
    if args.pre_flip_enable:
        args.train_path = _apply_pre_flip(args, model, tok, device)
    if args.hardflip_enable:
        args.train_path = _apply_hard_flip(args)
    if args.smooth_enable:
        args.train_path = _apply_smoothing(args, model, tok, device)

    print("[Stage 2] Loading best projection head and freezing it.")
    best_proj_path = os.path.join(args.output_dir, "best_projection.pt")
    if not os.path.exists(best_proj_path):
        sys.exit("Error: Best projection head not found. Stage 1 might have failed.")
    # Use map_location="cpu" so a checkpoint saved from a GPU run also loads on CPU-only hosts.
    # load_state_dict then copies the tensors onto the model's device.
    model.load_state_dict(torch.load(best_proj_path, map_location="cpu"))
    model.requires_grad_(False)
    model.eval()

    tr_rows = JsonlRows(args.train_path, args.model_id, None).rows
    if not tr_rows:
        sys.exit(f"[Stage 2] Empty training set after pre-flip: {args.train_path}")
    print(f"[Stage 2] Using training data: {args.train_path} | N={len(tr_rows)} | "
          f"pos={sum(r['label'] for r in tr_rows)}")

    clf_head = MLPHead(args.projection_dim).to(device)
    optim_clf = torch.optim.AdamW(clf_head.parameters(), lr=args.lr_classifier)

    train_dl_simple = DataLoader(
        SimpleDataset(tr_rows),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda b: collate_classification(b, tok, args.max_len),
    )

    num_steps_per_epoch = max(1, len(train_dl_simple))
    scheduler = _build_scheduler(args, optim_clf, num_steps_per_epoch)
    step_level_scheduler = args.scheduler in ("onecycle", "cosine", "linear_warmup")

    # EMA of the classification head only.
    ema = ModelEMA(clf_head, decay=args.ema_decay) if args.ema_decay > 0 else None

    crit = (AsymBCEWithLogits(gamma_pos=args.gamma_pos, gamma_neg=args.gamma_neg, fp_cost=args.fp_cost)
            if args.loss_type == "asym_bce" else None)

    # Class-imbalance weights [neg, pos]. These are inverse class frequencies with each
    # count clamped to >= 1, normalised so that their mean is 1.
    pos_count = sum(r['label'] for r in tr_rows)
    neg_count = len(tr_rows) - pos_count
    class_counts = torch.tensor([max(neg_count, 1), max(pos_count, 1)], dtype=torch.float32)
    ce_weights = class_counts.sum() / (2.0 * class_counts)
    ce_weights = (ce_weights / ce_weights.mean()).to(device)
    print(f"[Stage 2] Using CrossEntropy weights for imbalance: {ce_weights.cpu().numpy()}")

    # All-zero prompt set for FP analysis. It comes from the EVAL reference data, not from training data.
    fp_ref_path = args.fp_allzero_ref_path or args.val_path
    all_zero_prompts_eval = collect_all_zero_prompt_ids(fp_ref_path)
    if not all_zero_prompts_eval:
        print("[FP-Ref] Warning: 未在评估参考数据里统计到 all-zero prompts。"
              " 请确认传入的是包含多模型同一 prompt 的 jsonl，并且字段 {prompt_id, model_id, label} 完整。")
    else:
        print(f"[FP-Ref] all_zero_prompts_eval = {len(all_zero_prompts_eval)} from {fp_ref_path}")

    # With max_norm=inf, clip_grad_norm_ only measures the norm: the clip coefficient clamps to 1.
    max_norm = args.max_grad_norm if (args.max_grad_norm and args.max_grad_norm > 0) else float("inf")

    best_f1 = -1.0
    best_prec_score = -1.0
    for epoch in range(1, args.epochs_classifier + 1):
        clf_head.train()
        pbar = tqdm(train_dl_simple, desc=f"[Stage 2 Epoch {epoch}] Training Classifier")
        for enc, y, _ in pbar:
            enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
            y = y.to(device, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type=="cuda")):
                # The encoder is frozen. Only the classification head receives gradients.
                with torch.no_grad():
                    outputs = model(**enc)
                z = outputs["projection_embedding"]
                logits = clf_head(z)

            if args.loss_type == "asym_bce":
                loss = crit(logits, y, ce_weights=ce_weights)
            else:
                # Weighted BCE on the positive-class logit. The collate step yields long labels.
                y_float = y.float()
                p_logit = logits[:, 1]
                # Map the per-class weights onto per-sample weights.
                sw = torch.where(y_float >= 0.5, ce_weights[1], ce_weights[0])
                loss = F.binary_cross_entropy_with_logits(p_logit, y_float, weight=sw)

            optim_clf.zero_grad(set_to_none=True)
            loss.backward()
            grad_norm = float(torch.nn.utils.clip_grad_norm_(clf_head.parameters(), max_norm))
            optim_clf.step()

            if step_level_scheduler:
                scheduler.step()
            if ema is not None:
                ema.update()

            lr_logs = {f"train_classifier/lr_group{i}": pg["lr"] for i, pg in enumerate(optim_clf.param_groups)}
            if optim_clf.param_groups:
                lr_logs.setdefault("train_classifier/lr", optim_clf.param_groups[0]["lr"])

            pbar.set_postfix({"ce_loss": loss.item(), "grad_norm": grad_norm})
            wandb.log({
                "train_classifier/step_loss": loss.item(),
                "train_classifier/grad_norm": grad_norm,
                **lr_logs,
            })

        # --- Evaluation. When EMA is on, evaluate and save the EMA weights. ---
        if ema is not None:
            ema.apply_shadow()

        metrics = evaluate_classifier(model, clf_head, val_dl_simple, device, amp_dtype)
        print(f"[Stage 2 Epoch {epoch}] Val Metrics: "
              f"F1={metrics['f1']:.4f}, Acc={metrics['acc']:.4f}, AUC={metrics['auc']:.4f}")
        wandb.log({"eval_classifier/f1": metrics['f1'], "eval_classifier/acc": metrics['acc'], "epoch_classifier": epoch,
                   "eval_classifier/Precision": metrics['precision'], "eval_classifier/TP": metrics['TP'], "eval_classifier/FP": metrics['FP']})

        final_model_state = {
            'model_projection_state_dict': model.projection_head.state_dict(),
            'classifier_head_state_dict': clf_head.state_dict(),
            'args': vars(args),
        }
        torch.save(final_model_state, os.path.join(args.output_dir, f"epoch_{epoch}_model.pt"))
        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            print("[Stage 2] New best F1 score. Saving best classifier to best_classifier.pt")
            torch.save(final_model_state, os.path.join(args.output_dir, "best_classifier.pt"))

        # Validation scores plus prompt ids: used for FP analysis and for picking
        # a threshold that meets the precision constraint.
        y_val, s_val, pid_val = collect_scores_and_pids(model, clf_head, val_dl_simple, device, amp_dtype)

        # Find out how many false positives fall on "hard" (all-zero) prompts.
        fp_stats = analyze_fp_difficulty(
            y_true=y_val,
            y_score=s_val,
            prompt_ids=pid_val,
            all_zero_prompts=all_zero_prompts_eval,
            out_csv=os.path.join(args.output_dir, "fp_analysis.csv")
        )
        wandb.log({"val/fp_stats": fp_stats})

        thr_star, thr_stat = find_threshold_for_precision(y_val, s_val, prec_target=args.prec_target)
        print(f"[Val] precision-constraint thr*: {thr_star:.4f} | "
              f"P={thr_stat['precision']:.4f}, R={thr_stat['recall']:.4f}, Sel={thr_stat['num_selected']}")
        wandb.log({"val/prec_thr": thr_star, "val/prec_at_thr": thr_stat["precision"],
                   "val/recall_at_thr": thr_stat["recall"]})

        # Model-selection score at the precision-constrained threshold: recall or F1.
        if args.save_best_by == "recall":
            score_this = thr_stat["recall"]
        else:
            p, r = thr_stat["precision"], thr_stat["recall"]
            score_this = 2 * p * r / max(1e-8, p + r)
        if score_this > best_prec_score:
            best_prec_score = score_this
            final_model_state = {
                'model_projection_state_dict': model.projection_head.state_dict(),
                'classifier_head_state_dict':  clf_head.state_dict(),
                'args': vars(args),
                'prec_constrained': {
                    'target': args.prec_target,
                    'thr':    thr_star,
                    'val_precision': thr_stat['precision'],
                    'val_recall':    thr_stat['recall'],
                },
            }
            torch.save(final_model_state, os.path.join(args.output_dir, "best_f1_classifier.pt"))
            print(f"[Stage 2] Saved with precision≥{args.prec_target} thr*={thr_star:.4f} "
                  f"(P={thr_stat['precision']:.3f}, R={thr_stat['recall']:.3f})")

        if ema is not None:
            ema.restore()

        # Plateau scheduling: one step per epoch on the selection score (higher is better).
        if args.scheduler == "plateau":
            scheduler.step(score_this)

    print("\n[Done] Two-stage training finished.")
    wandb.finish()


if __name__ == "__main__":
    main()

