#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os, json, argparse, random, datetime
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, AutoModel, DataCollatorWithPadding, get_cosine_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score, average_precision_score, roc_auc_score
from tqdm import tqdm
import torch.nn.functional as F
import wandb

# ---------------------------
# Utils
# ---------------------------

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all-CUDA-device) RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# ---------------------------
# Dataset
# ---------------------------

class JsonlMultiLabelDataset(Dataset):
    """In-memory multi-label dataset loaded from a JSON-lines file.

    Every non-blank line must be a JSON object holding the input text under
    ``text_key`` and a list of 0/1 indicators under ``label_key``. All label
    lists must have the same length; that length is exposed as ``num_labels``.

    Args:
        path: Path of the ``.jsonl`` file (read as UTF-8).
        text_key: Key of the text field in each record.
        label_key: Key of the label list in each record.
        max_samples: Stop reading after this many samples. ``None`` (or 0)
            means "read everything".

    Raises:
        ValueError: If a record's labels are not a list, if label lists have
            inconsistent lengths, or if the file yields no samples.
        KeyError: If a record lacks ``text_key`` or ``label_key``.
    """

    def __init__(self, path: str, text_key: str = "text", label_key: str = "labels", max_samples: Optional[int] = None):
        self.samples = []
        expected_len = None
        with open(path, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, 1):
                s = line.strip()
                if not s:
                    continue
                ex = json.loads(s)
                labels = ex[label_key]
                # Explicit raise instead of `assert`: assertions vanish under `python -O`.
                if not isinstance(labels, list):
                    raise ValueError(f"{path}:{lineno}: labels must be a list of 0/1")
                # Fail fast on ragged label vectors; otherwise they only surface
                # later as a tensor-construction or loss shape error.
                if expected_len is None:
                    expected_len = len(labels)
                elif len(labels) != expected_len:
                    raise ValueError(
                        f"{path}:{lineno}: expected {expected_len} labels, got {len(labels)}"
                    )
                self.samples.append({"text": ex[text_key], "labels": [int(v) for v in labels]})
                if max_samples and len(self.samples) >= max_samples:
                    break
        if not self.samples:
            raise ValueError(f"No samples in {path}")
        self.num_labels = len(self.samples[0]["labels"])

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``{"text": ..., "labels": [...]}`` dict at position ``idx``."""
        return self.samples[idx]

def collate_fn(batch, tok, max_len):
    """Tokenize a list of samples into one padded batch and attach the labels.

    ``tok`` is called on the list of texts with truncation to ``max_len`` and
    padding enabled; its result gets an extra ``"labels"`` float32 tensor.
    """
    texts, label_rows = [], []
    for sample in batch:
        texts.append(sample["text"])
        label_rows.append(sample["labels"])
    encoded = tok(texts, truncation=True, max_length=max_len, padding=True, return_tensors="pt")
    # Labels are forced to float32 to avoid bf16 all_gather errors downstream.
    encoded["labels"] = torch.tensor(label_rows, dtype=torch.float32)
    return encoded

# ---------------------------
# Model: HF encoder + Transformer Decoder head
# ---------------------------

class DecoderHead(nn.Module):
    """Label-query decoder head.

    Learned label embeddings act as the decoder queries and the text tokens
    act as the memory; the output has shape (B, L) with one logit per label.
    """

    def __init__(self, d_model: int, num_labels: int, n_layers: int = 2, n_heads: int = 8, dropout: float = 0.1, ffn_mult: int = 4):
        super().__init__()
        layer_options = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * ffn_mult,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        decoder_layer = nn.TransformerDecoderLayer(**layer_options)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.label_queries = nn.Embedding(num_labels, d_model)
        # One logit is produced for every label query.
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, memory, memory_key_padding_mask=None):
        """Score every label against ``memory``.

        ``memory`` is the (B, T, d_model) encoder output. For
        ``memory_key_padding_mask`` a True entry means "masked out".
        """
        batch_size = memory.size(0)
        num_queries = self.label_queries.num_embeddings
        queries = self.label_queries.weight.unsqueeze(0).expand(batch_size, num_queries, -1)  # (B, L, d)
        decoded = self.decoder(
            tgt=queries,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.classifier(decoded).squeeze(-1)  # (B, L)

def mean_pooling(model_output, attention_mask):
    """Average each sequence's token embeddings over the positions kept by ``attention_mask``.

    ``model_output[0]`` is taken as the per-token embeddings. The divisor is
    clamped to at least 1e-9 so an all-zero mask yields zeros instead of NaN.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * mask).sum(1)
    kept_count = mask.sum(1).clamp(min=1e-9)
    return masked_sum / kept_count

class EasyHead(nn.Module):
    """Single linear layer mapping a feature vector to ``num_labels`` logits."""

    def __init__(self, hidden_size, num_labels=5):
        super().__init__()
        self.classifier = nn.Linear(in_features=hidden_size, out_features=num_labels)

    def forward(self, x):
        """Return the logits for ``x``."""
        logits = self.classifier(x)
        return logits

class LowRankBilinearHead(nn.Module):
    """Low-rank bilinear head.

    Features go through LayerNorm and dropout, are projected to ``rank``
    dimensions, and are then scored against one learned ``rank``-dimensional
    vector per label, plus a per-label bias.
    """

    def __init__(self, hidden_size, num_labels, rank=64, p=0.1):
        super().__init__()
        self.pre = nn.Sequential(nn.LayerNorm(hidden_size), nn.Dropout(p))
        # Projection x -> r (no bias).
        self.U = nn.Linear(hidden_size, rank, bias=False)
        # One r-dimensional vector per label, re-initialised with std 0.02.
        label_vectors = torch.randn(num_labels, rank)
        self.V = nn.Parameter(label_vectors)
        nn.init.normal_(self.V, std=0.02)
        self.bias = nn.Parameter(torch.zeros(num_labels))

    def forward(self, x):
        """Return logits of shape [B, num_labels] for features ``x``."""
        projected = self.U(self.pre(x))  # [B, r]
        return torch.matmul(projected, self.V.T) + self.bias

class MLPHead(nn.Module):
    """Two-layer MLP head: LayerNorm, dropout, a ``width``-times wider GELU layer, dropout, then the label logits."""

    def __init__(self, hidden_size, num_labels, width=4, p=0.1):
        super().__init__()
        inner_size = hidden_size * width
        stages = [
            nn.LayerNorm(hidden_size),
            nn.Dropout(p),
            nn.Linear(hidden_size, inner_size),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(inner_size, num_labels),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the logits for features ``x``."""
        logits = self.net(x)
        return logits
    
class CosineHead(nn.Module):
    """Cosine-similarity head.

    Logits are the cosine between the input and one learned vector per label,
    multiplied by a learned scale (shared, or one per label when
    ``per_class_scale`` is set; initialised to 10) and shifted by a per-label bias.
    """

    def __init__(self, hidden_size, num_labels, per_class_scale=False):
        super().__init__()
        self.W = nn.Parameter(torch.empty(num_labels, hidden_size))
        nn.init.kaiming_uniform_(self.W)
        scale_count = num_labels if per_class_scale else 1
        self.scale = nn.Parameter(torch.ones(scale_count) * 10.0)
        self.bias = nn.Parameter(torch.zeros(num_labels))

    def forward(self, x):
        """Return scaled cosine logits for features ``x``."""
        unit_x = F.normalize(x, dim=-1)
        unit_w = F.normalize(self.W, dim=-1)
        cosine = unit_x @ unit_w.T
        return cosine * self.scale + self.bias

class HFDecoderMultiLabel(nn.Module):
    """Hugging Face encoder (frozen or fine-tuned) plus a multi-label classification head.

    Despite the class name, the head wired in here is ``MLPHead`` applied to
    the mean-pooled token embeddings. The loss is ``BCEWithLogitsLoss``.
    """

    def __init__(self, model_name: str, num_labels: int, freeze_backbone: bool = True):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden_size = getattr(self.backbone.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Cannot infer hidden_size from backbone config.")
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
        # Alternative heads defined in this file: EasyHead, LowRankBilinearHead, CosineHead.
        self.head = MLPHead(hidden_size, num_labels=num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.num_labels = num_labels

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Return ``{"loss": ..., "logits": ...}``; ``loss`` is None when ``labels`` is None."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = mean_pooling(encoded, attention_mask)
        logits = self.head(pooled)  # (B, L)
        if labels is None:
            loss = None
        else:
            # Logits are cast to the label dtype before the loss is computed.
            loss = self.loss_fn(logits.to(labels.dtype), labels)
        return {"loss": loss, "logits": logits}

# ---------------------------
# Distributed helpers and evaluation
# ---------------------------

def make_sample_weights(ds, power=1.0, clip_min=0.2, clip_max=5.0, eps=1e-6):
    """Compute one sampling weight per sample from inverse label frequencies.

    A sample's raw weight is the sum of the inverse positive frequencies of the
    labels it carries. It is then raised to ``power`` and clipped to
    ``[clip_min, clip_max]``. Returns a 1-D ``torch.double`` tensor.
    """
    label_matrix = np.array([sample["labels"] for sample in ds.samples], dtype=np.float32)  # (N, L) in {0, 1}
    positive_rate = label_matrix.mean(0)                      # positive frequency per label
    inverse_rate = 1.0 / np.clip(positive_rate, eps, None)    # eps guards labels with no positives
    raw = (label_matrix * inverse_rate).sum(1)
    weights = np.clip(np.power(raw, power), clip_min, clip_max)
    return torch.as_tensor(weights, dtype=torch.double)


def init_distributed_if_needed():
    """Initialise the default process group when launched with RANK and WORLD_SIZE set.

    Uses the "nccl" backend when CUDA is available and "gloo" otherwise, with a
    7200-second timeout, and binds the process to the LOCAL_RANK CUDA device.
    Returns True if the group was initialised, False otherwise.
    """
    launched = all(key in os.environ for key in ("RANK", "WORLD_SIZE"))
    if not launched:
        return False

    has_cuda = torch.cuda.is_available()
    dist.init_process_group(
        backend="nccl" if has_cuda else "gloo",
        timeout=datetime.timedelta(seconds=7200),
    )
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if has_cuda:
        torch.cuda.set_device(local_rank)
    return True

def get_rank():
    """Return the process rank read from the RANK environment variable (0 when unset)."""
    rank_text = os.environ.get("RANK", "0")
    return int(rank_text)


def is_main():
    """Return True on the rank-0 process."""
    rank = get_rank()
    return rank == 0

def _gather_numpy_ndarray(x: np.ndarray, device: torch.device):
    if not dist.is_available() or not dist.is_initialized():
        return x
    local_len = np.array([x.shape[0]], dtype=np.int64)
    local_len_t = torch.from_numpy(local_len).to(device)
    world_size = dist.get_world_size()
    lens_t = [torch.zeros_like(local_len_t) for _ in range(world_size)]
    dist.all_gather(lens_t, local_len_t)
    lens = [int(t.cpu().numpy()[0]) for t in lens_t]
    max_len = max(lens)
    pad_width = [(0, max_len - x.shape[0])] + [(0,0)]*(x.ndim-1)
    x_pad = np.pad(x, pad_width, mode="constant")
    x_pad_t = torch.from_numpy(x_pad).to(device)
    gathered = [torch.zeros_like(x_pad_t) for _ in range(world_size)]
    dist.all_gather(gathered, x_pad_t)
    if dist.get_rank() == 0:
        outs = []
        for arr_t, ln in zip(gathered, lens):
            outs.append(arr_t[:ln].cpu().numpy())
        return np.concatenate(outs, axis=0)
    else:
        return None

@torch.no_grad()
def evaluate_multilabel_sharded(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode and gather logits and labels.

    Returns ``(logits, labels)`` as produced by ``_gather_numpy_ndarray``: the
    full arrays when not distributed or on rank 0, and ``(None, None)`` on the
    other ranks. The model is left in eval mode.
    """
    model.eval()
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    logit_chunks, label_chunks = [], []
    for batch in loader:
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        batch_labels = batch["labels"].to(device, non_blocking=True)  # float32
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=bf16_ok):
            batch_logits = model(input_ids=input_ids, attention_mask=attention_mask)["logits"]
        # Cast to float32 before leaving the device (bf16 has no NumPy equivalent).
        logit_chunks.append(batch_logits.float().cpu().numpy())
        label_chunks.append(batch_labels.cpu().numpy())

    def _stack(chunks):
        # An empty shard still needs a (0, num_labels) array for the gather step.
        if chunks:
            return np.concatenate(chunks, axis=0)
        core = model.module if hasattr(model, "module") else model
        return np.zeros((0, core.num_labels), dtype=np.float32)

    local_logits = _stack(logit_chunks)
    local_labels = _stack(label_chunks)
    gathered_logits = _gather_numpy_ndarray(local_logits, device)
    gathered_labels = _gather_numpy_ndarray(local_labels, device)
    return gathered_logits, gathered_labels

# ---------------------------
# Threshold search (precision as the objective)
# ---------------------------

class ThresholdSearch:
    """Grid search for the decision threshold(s) that maximise precision.

    The grid runs from 0.05 to 0.95 in steps of ``grid_step``. With
    ``per_class=False`` a single threshold maximising micro precision is kept
    in ``best_t``; with ``per_class=True`` one threshold per label (maximising
    that label's binary precision) is kept in ``best_t_per_class``. ``base`` is
    the fallback when no grid point scores above -1.
    """

    def __init__(self, base: float = 0.5, per_class: bool = False, grid_step: float = 0.05):
        self.base = base
        self.per_class = per_class
        self.grid = np.arange(0.05, 0.95 + 1e-9, grid_step)
        self.best_t = base
        self.best_t_per_class = None

    def _argmax_over_grid(self, score_at):
        """Return the first grid threshold with the strictly highest ``score_at(t)``."""
        top_score, top_t = -1.0, self.base
        for t in self.grid:
            score = score_at(t)
            if score > top_score:
                top_score, top_t = score, t
        return top_t

    def __call__(self, eval_pred):
        """Search thresholds on ``(logits, labels)`` and return a metrics dict."""
        logits, labels = eval_pred
        probs = 1 / (1 + np.exp(-logits))
        y_true = labels.astype(int)
        num_labels = y_true.shape[1]

        if self.per_class:
            thresholds = np.full(num_labels, self.base, dtype=np.float32)
            for j in range(num_labels):
                col_probs, col_true = probs[:, j], y_true[:, j]
                thresholds[j] = self._argmax_over_grid(
                    lambda t: precision_score(
                        col_true, (col_probs >= t).astype(int), average="binary", zero_division=0
                    )
                )
            self.best_t_per_class = thresholds
            y_pred = (probs >= thresholds.reshape(1, -1)).astype(int)
        else:
            chosen = self._argmax_over_grid(
                lambda t: precision_score(
                    y_true, (probs >= t).astype(int), average="micro", zero_division=0
                )
            )
            self.best_t = float(chosen)
            y_pred = (probs >= self.best_t).astype(int)

        # The remaining metrics are reported alongside for easier inspection.
        metrics = {
            "val/precision_micro": precision_score(y_true, y_pred, average="micro", zero_division=0),
            "val/recall_micro": recall_score(y_true, y_pred, average="micro", zero_division=0),
            "val/f1_micro": f1_score(y_true, y_pred, average="micro", zero_division=0),
            "val/f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        }
        # Ranking metrics fall back to NaN when the scorer raises.
        for key, scorer in (
            ("val/rocauc_macro", roc_auc_score),
            ("val/prauc_macro", average_precision_score),
        ):
            try:
                metrics[key] = scorer(y_true, probs, average="macro")
            except Exception:
                metrics[key] = float("nan")

        if self.per_class:
            metrics["val/thresholds_json"] = json.dumps([float(x) for x in self.best_t_per_class])
        else:
            metrics["val/threshold"] = float(self.best_t)
        return metrics


def pick_route_from_probs(p, thr=0.5, default_id=0):
    """Pick one column per row of ``p``: the most probable among those strictly above threshold.

    ``thr`` is a scalar or a per-column sequence (compared as float32). Rows
    with no column above its threshold get ``default_id``. Returns
    ``(picks, has)`` where ``has`` flags the rows that had a candidate.
    """
    num_rows, num_cols = p.shape[0], p.shape[1]
    if np.isscalar(thr):
        thr_vec = np.full(num_cols, thr, dtype=np.float32)
    else:
        thr_vec = np.asarray(thr, dtype=np.float32)

    above = (p - thr_vec.reshape(1, -1)) > 0
    has = above.any(axis=1)

    picks = np.full(num_rows, default_id, dtype=np.int64)
    # Columns that did not clear the threshold are excluded from the argmax.
    eligible = np.where(above[has], p[has], -np.inf)
    picks[has] = eligible.argmax(axis=1)
    return picks, has


# ---------------------------
# Main training loop
# ---------------------------

def _build_train_sampler(args, train_ds, distributed):
    """Return ``(sampler, shuffle)`` for the training DataLoader.

    With ``--balanced_sampling`` a per-rank WeightedRandomSampler is built from
    ``make_sample_weights``; otherwise a DistributedSampler is used when
    distributed, and plain shuffling when not.
    """
    if args.balanced_sampling:
        import math

        weights = make_sample_weights(train_ds, args.balance_power, args.balance_clip_min, args.balance_clip_max)
        generator = torch.Generator()
        if dist.is_available() and dist.is_initialized():
            # A different seed per rank so ranks do not draw identical samples.
            generator.manual_seed(args.seed + dist.get_rank())
            per_rank = math.ceil(len(train_ds) / dist.get_world_size())
        else:
            per_rank = len(train_ds)
        sampler = torch.utils.data.WeightedRandomSampler(
            weights=weights, num_samples=per_rank, replacement=True, generator=generator
        )
        return sampler, False

    sampler = DistributedSampler(train_ds, shuffle=True) if distributed else None
    return sampler, sampler is None


def _validation_metrics(args, logits, labels):
    """Compute the per-epoch validation metrics from gathered logits and labels.

    Returns ``(metrics, best_thr, best_thr_vec)``. Exactly one of ``best_thr``
    (scalar threshold) and ``best_thr_vec`` (per-label thresholds) is not None.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    y_true = labels.astype(int)
    y_pred = (probs >= 0.5).astype(int)

    metrics = {
        "val/precision_micro@0.5": precision_score(y_true, y_pred, average="micro", zero_division=0),
        "val/recall_micro@0.5":    recall_score(y_true, y_pred, average="micro", zero_division=0),
        "val/f1_micro@0.5":        f1_score(y_true, y_pred, average="micro", zero_division=0),
        "val/f1_macro@0.5":        f1_score(y_true, y_pred, average="macro", zero_division=0),
    }
    # Ranking metrics fall back to NaN when the scorer raises.
    try:
        metrics["val/rocauc_macro"] = roc_auc_score(y_true, probs, average="macro")
    except Exception:
        metrics["val/rocauc_macro"] = float("nan")
    try:
        metrics["val/prauc_macro"] = average_precision_score(y_true, probs, average="macro")
    except Exception:
        metrics["val/prauc_macro"] = float("nan")

    if args.per_class_thresholds:
        # The per-label thresholds must actually be searched for here; reading
        # "val/thresholds_json" without producing it raised KeyError.
        searcher = ThresholdSearch(base=args.threshold, per_class=True, grid_step=args.thr_grid_step)
        metrics["val/thresholds_json"] = searcher((logits, labels))["val/thresholds_json"]
        thr_vec = np.array(json.loads(metrics["val/thresholds_json"]), dtype=np.float32)
        preds = (probs >= thr_vec.reshape(1, -1)).astype(int)
        best_thr, best_thr_vec = None, thr_vec
    else:
        thr = 0.5
        preds = (probs >= thr).astype(int)
        best_thr, best_thr_vec = thr, None

    # Subset accuracy = exact match of the whole label vector.
    metrics["val_subset_accuracy@0.5"] = float((preds == y_true).all(axis=1).mean())
    # Routing accuracy = the single picked label is a true label of the sample.
    route_idx, has_cand = pick_route_from_probs(probs, thr=0.5, default_id=0)
    metrics["val_routing_accuracy@0.5"] = float(y_true[np.arange(y_true.shape[0]), route_idx].mean())
    metrics["val_abstain_rate@0.5"] = float((~has_cand).mean())
    return metrics, best_thr, best_thr_vec


def train_torch(args):
    """Train the multi-label classifier, evaluating and checkpointing after every epoch.

    Per epoch, rank 0 writes ``save_epoch_{ep}/`` (weights and thresholds).
    When micro precision at 0.5 improves it also writes ``best.pt``, the
    tokenizer, ``meta.json``, the threshold file and a resumable checkpoint
    under ``ckpt_epoch_{ep}/``.

    Behaviour notes:
      * Gradient accumulation also flushes the trailing partial group of every
        epoch, and the scheduler length counts that extra update.
      * ``--per_class_thresholds`` searches per-label thresholds with
        ``ThresholdSearch``.
      * In distributed mode a CUDA device is selected only when CUDA exists,
        so a gloo/CPU launch keeps running on the CPU.

    Args:
        args: Namespace produced by ``build_parser()``.
    """
    import math

    distributed = init_distributed_if_needed()
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_ok else "cpu")
    if distributed and cuda_ok:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(local_rank)

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    set_seed(args.seed)

    if is_main() and args.wandb:
        run_name = f"{os.path.basename(args.model_name)}-e{args.epochs}-b{args.batch_size*args.grad_accum}-lr{args.lr}"
        wandb.init(
            project=args.wandb_proj,
            name=run_name,
            config=args,  # records every argparse hyper-parameter
        )

    # Tokenizer and datasets
    tok = AutoTokenizer.from_pretrained(args.model_name)
    if tok.pad_token is None and hasattr(tok, "eos_token"):
        tok.pad_token = tok.eos_token

    train_ds = JsonlMultiLabelDataset(args.train_path, args.text_key, args.label_key, args.max_train_samples)
    val_ds = JsonlMultiLabelDataset(args.val_path, args.text_key, args.label_key, args.max_val_samples)

    # Model
    model = HFDecoderMultiLabel(
        model_name=args.model_name,
        num_labels=train_ds.num_labels,
        freeze_backbone=args.freeze_backbone,
    ).to(device)

    # Data loaders
    train_sampler, shuffle_flag = _build_train_sampler(args, train_ds, distributed)
    val_sampler = DistributedSampler(val_ds, shuffle=False) if distributed else None
    pad_collate = lambda batch: collate_fn(batch, tok, args.max_len)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=shuffle_flag,
                              sampler=train_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=pad_collate)
    val_loader = DataLoader(val_ds, batch_size=args.eval_batch_size or args.batch_size, shuffle=False,
                            sampler=val_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=pad_collate)

    if distributed:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=[device.index] if device.type == 'cuda' else None, find_unused_parameters=True)

    if is_main() and args.wandb:
        # log="all" records gradient and parameter histograms.
        wandb.watch(model, log="all", log_freq=10)

    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)

    accum = max(1, args.grad_accum)
    num_batches = len(train_loader)
    tail = num_batches % accum  # size of the trailing partial accumulation group
    # ceil, not floor: the trailing partial group also triggers an optimizer
    # step, and the cosine schedule must span every step that really happens.
    num_update_steps_per_epoch = max(1, math.ceil(num_batches / accum))
    total_steps = num_update_steps_per_epoch * args.epochs
    num_warmup_steps = int(total_steps * args.warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps, total_steps)

    use_bf16 = cuda_ok and torch.cuda.is_bf16_supported()

    if is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "hyperparams_config.jsonl"), "w") as f:
            f.write(json.dumps(vars(args)))

    global_step = 0
    best_prec = -1.0
    # NOTE: early stopping is disabled, so bad_epochs is only tracked. Breaking
    # out on rank 0 alone would leave the other ranks waiting in the next collective.
    bad_epochs = 0

    for ep in range(1, args.epochs + 1):
        sampler = train_loader.sampler
        if distributed and isinstance(sampler, torch.utils.data.distributed.DistributedSampler):
            sampler.set_epoch(ep)
        elif isinstance(sampler, torch.utils.data.WeightedRandomSampler):
            # Reseed per epoch and per rank so the draws differ across both.
            in_group = dist.is_available() and dist.is_initialized()
            seed_base = args.seed + ep * (dist.get_world_size() if in_group else 1)
            rank = dist.get_rank() if in_group else 0
            sampler.generator.manual_seed(seed_base + rank)

        model.train()
        optim.zero_grad(set_to_none=True)
        progress_bar = tqdm(enumerate(train_loader, 1), total=num_batches, disable=not is_main(), ncols=100)

        for step, batch in progress_bar:
            ids = batch["input_ids"].to(device, non_blocking=True)
            att = batch["attention_mask"].to(device, non_blocking=True)
            labs = batch["labels"].to(device, non_blocking=True)

            # Micro-batches of the trailing partial group are averaged over
            # that group's real size rather than over the full `accum`.
            group = tail if (tail and step > num_batches - tail) else accum

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
                out = model(input_ids=ids, attention_mask=att, labels=labs)
                loss = out["loss"] / group

            loss.backward()
            # Step on every full group and on the last batch of the epoch, so
            # leftover gradients are applied instead of being zeroed away.
            if step % accum == 0 or step == num_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optim.step()
                scheduler.step()
                optim.zero_grad(set_to_none=True)
                global_step += 1
                step_loss = loss.item() * group
                current_lr = scheduler.get_last_lr()[0]
                progress_bar.set_postfix({"loss": f"{step_loss:.4f}", "lr": f"{current_lr:.2e}"})

                if is_main() and args.wandb:
                    wandb.log({
                        "train/step_loss": step_loss,
                        "train/learning_rate": current_lr,
                    }, step=global_step)

        if distributed:
            dist.barrier()
        logits, labels = evaluate_multilabel_sharded(model, val_loader, device)

        # Only rank 0 receives the gathered arrays; it owns metrics and checkpoints.
        if is_main() and (logits is not None) and (labels is not None):
            metrics, best_thr, best_thr_vec = _validation_metrics(args, logits, labels)

            if args.wandb:
                log_metrics = {k: (float(v) if isinstance(v, (np.floating,)) else v) for k, v in metrics.items()}
                log_metrics["epoch"] = ep
                wandb.log(log_metrics, step=global_step)

            # Per-epoch snapshot: weights plus the thresholds used this epoch.
            output_dir_ep = os.path.join(args.output_dir, f"save_epoch_{ep}")
            os.makedirs(output_dir_ep, exist_ok=True)
            core_model = model.module if hasattr(model, "module") else model
            torch.save(core_model.state_dict(), os.path.join(output_dir_ep, f"ckpt_epoch_{ep}.pt"))
            with open(os.path.join(output_dir_ep, "best_thresholds.json"), "w") as f:
                if args.per_class_thresholds:
                    json.dump([float(x) for x in best_thr_vec.tolist()], f)
                else:
                    json.dump(str(best_thr), f)

            improved = metrics["val/precision_micro@0.5"] > best_prec + 1e-6
            if improved:
                best_prec = metrics["val/precision_micro@0.5"]
                bad_epochs = 0

                if args.wandb:
                    wandb.summary["best_precision_micro@0.5"] = best_prec
                    wandb.summary["best_epoch"] = ep

                torch.save(core_model.state_dict(), os.path.join(args.output_dir, "best.pt"))
                tok.save_pretrained(args.output_dir)
                checkpoint = {
                    'epoch': ep,
                    'global_step': global_step,
                    'model_state_dict': core_model.state_dict(),
                    'optimizer_state_dict': optim.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_precision': best_prec,
                }
                ckpt_dir = os.path.join(args.output_dir, f"ckpt_epoch_{ep}")
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(checkpoint, os.path.join(ckpt_dir, "best_checkpoint.pt"))
                with open(os.path.join(args.output_dir, "meta.json"), "w", encoding="utf-8") as f:
                    json.dump({"num_labels": train_ds.num_labels, "model_name": args.model_name}, f)
                # Thresholds that go with best.pt
                if best_thr_vec is not None:
                    with open(os.path.join(args.output_dir, "best_thresholds.json"), "w") as f:
                        json.dump([float(x) for x in best_thr_vec.tolist()], f)
                else:
                    with open(os.path.join(args.output_dir, "best_threshold.txt"), "w") as f:
                        f.write(str(best_thr))
            else:
                bad_epochs += 1

            print({k: (float(v) if isinstance(v, (np.floating,)) else v) for k, v in metrics.items()})

    if is_main() and args.wandb:
        wandb.finish()

    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def load_state_dict_from_dir(ckpt_dir: str) -> Dict[str, torch.Tensor]:
    """Load model weights from ``ckpt_dir`` onto the CPU.

    ``best.pt`` is preferred and ``pytorch_model.bin`` is the fallback.

    Raises:
        FileNotFoundError: If neither file exists in ``ckpt_dir``.
    """
    for filename in ("best.pt", "pytorch_model.bin"):
        candidate = os.path.join(ckpt_dir, filename)
        if os.path.exists(candidate):
            return torch.load(candidate, map_location="cpu")
    raise FileNotFoundError(f"no weights found in {ckpt_dir}")

@torch.no_grad()
def predict_torch(args):
    """Run inference on ``args.test_path`` and write ``predictions.csv`` to ``args.output_dir``.

    Weights and thresholds are read from ``args.ckpt`` (a directory) or, when
    that is unset, from ``args.output_dir``. A ``best_thresholds.json`` vector
    takes precedence. Otherwise ``args.threshold`` is used, falling back to
    ``best_threshold.txt`` and finally 0.5 when ``args.threshold`` is None.

    The CSV has one row per test sample: ``prob_i`` columns followed by
    ``pred_i`` columns.

    Args:
        args: Namespace produced by ``build_parser()``.

    Raises:
        FileNotFoundError: If no weight file is found in the checkpoint directory.
    """
    import csv

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tok.pad_token is None and hasattr(tok, "eos_token"):
        tok.pad_token = tok.eos_token

    test_base = JsonlMultiLabelDataset(args.test_path, args.text_key, args.label_key)
    pad_collate = lambda batch: collate_fn(batch, tok, args.max_len)
    test_loader = DataLoader(test_base, batch_size=args.eval_batch_size or args.batch_size,
                             shuffle=False, num_workers=args.num_workers, pin_memory=True, collate_fn=pad_collate)

    # Only the arguments the constructor actually accepts are passed; the
    # decoder-specific keywords used previously raised TypeError.
    model = HFDecoderMultiLabel(
        model_name=args.model_name,
        num_labels=test_base.num_labels,
        freeze_backbone=args.freeze_backbone,
    ).to(device).eval()

    # Load weights
    ckpt_dir = args.ckpt or args.output_dir
    sd = load_state_dict_from_dir(ckpt_dir)
    # strict=False is kept, but a mismatch is reported instead of passing silently.
    load_result = model.load_state_dict(sd, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"[WARN] state_dict mismatch: missing={list(load_result.missing_keys)} "
              f"unexpected={list(load_result.unexpected_keys)}")

    # Thresholds
    use_thr_vec, use_thr = None, args.threshold
    thr_json = os.path.join(ckpt_dir, "best_thresholds.json")
    thr_txt = os.path.join(ckpt_dir, "best_threshold.txt")
    if os.path.exists(thr_json):
        try:
            with open(thr_json, "r") as f:
                use_thr_vec = np.array(json.load(f), dtype=np.float32)
        except (OSError, ValueError, TypeError):
            use_thr_vec = None
        # A vector that cannot broadcast over the labels is ignored with a
        # warning rather than failing at comparison time.
        if use_thr_vec is not None and use_thr_vec.size not in (1, test_base.num_labels):
            print(f"[WARN] ignoring {thr_json}: expected {test_base.num_labels} thresholds, got {use_thr_vec.size}")
            use_thr_vec = None
    if use_thr_vec is None and use_thr is None and os.path.exists(thr_txt):
        try:
            with open(thr_txt, "r") as f:
                use_thr = float(f.read().strip())
        except (OSError, ValueError):
            use_thr = 0.5
    if use_thr is None:
        use_thr = 0.5

    # Inference
    all_logits = []
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                        enabled=torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
        for batch in test_loader:
            ids = batch["input_ids"].to(device, non_blocking=True)
            att = batch["attention_mask"].to(device, non_blocking=True)
            out = model(input_ids=ids, attention_mask=att)
            all_logits.append(out["logits"].float().cpu().numpy())

    logits = np.concatenate(all_logits, axis=0)
    probs = 1 / (1 + np.exp(-logits))
    if use_thr_vec is not None:
        preds = (probs >= use_thr_vec.reshape(1, -1)).astype(int)
    else:
        preds = (probs >= use_thr).astype(int)

    # Write CSV
    os.makedirs(args.output_dir, exist_ok=True)
    outp = os.path.join(args.output_dir, "predictions.csv")
    label_ids = range(test_base.num_labels)
    with open(outp, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow([*(f"prob_{i}" for i in label_ids), *(f"pred_{i}" for i in label_ids)])
        for pr, pb in zip(probs, preds):
            w.writerow([*(f"{x:.6f}" for x in pr), *pb.tolist()])
    thr_desc = "vector" if use_thr_vec is not None else f"{use_thr:.2f}"
    print(f"[OK] saved predictions to {outp} (threshold={thr_desc})")


def build_parser():
    """Build the command-line parser used by both the training and the prediction entry point."""
    parser = argparse.ArgumentParser("MiniLM + TransformerDecoder multilabel (precision-driven threshold)")

    def option(flag, kind, default=None, **extra):
        """Register a typed option."""
        parser.add_argument(flag, type=kind, default=default, **extra)

    def switch(flag):
        """Register a boolean flag that is False unless given."""
        parser.add_argument(flag, action="store_true", default=False)

    # Data
    option("--train_path", str)
    option("--val_path", str)
    option("--test_path", str)
    option("--text_key", str, "text")
    option("--label_key", str, "labels")
    option("--max_train_samples", int)
    option("--max_val_samples", int)

    # Model
    option("--model_name", str, "sentence-transformers/all-MiniLM-L6-v2")
    switch("--freeze_backbone")
    option("--dec_layers", int, 2)
    option("--dec_heads", int, 8)
    option("--ffn_mult", int, 4)
    option("--dropout", float, 0.1)

    # Train
    option("--output_dir", str, "./outputs_minilm_decoder")
    option("--epochs", int, 5)
    option("--batch_size", int, 16)
    option("--eval_batch_size", int)
    option("--max_len", int, 256)
    option("--lr", float, 2e-5)
    option("--weight_decay", float, 0.01)
    option("--warmup_ratio", float, 0.06)
    option("--seed", int, 42)
    option("--num_workers", int, 4)
    option("--patience", int, 3)
    option("--grad_accum", int, 8)

    # Loss & threshold
    switch("--use_pos_weight")
    option("--threshold", float, 0.5)
    switch("--per_class_thresholds")
    option("--thr_grid_step", float, 0.05)

    # Predict
    switch("--predict")
    option("--ckpt", str)

    # WandB
    switch("--wandb")
    option("--wandb_proj", str, "textclf")
    option("--logging_steps", int, 10, help="Log training metrics every N update steps.")

    # Balanced sampling
    switch("--balanced_sampling")
    option("--balance_power", float, 1.0)  # exponent applied to the sample weights, 0.5~1.0
    option("--balance_clip_min", float, 0.2)
    option("--balance_clip_max", float, 5.0)

    return parser

if __name__ == "__main__":
    # Script entry point: --predict selects inference, otherwise training runs.
    parser = build_parser()
    args = parser.parse_args()
    entry_point = predict_torch if args.predict else train_torch
    try:
        entry_point(args)
    except Exception:
        # Tear the process group down before re-raising the error.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
        raise
