#!/usr/bin/env python3
# train_axes_per_usage_qwen2audio.py
# - 一次性编码全部音频
# - mode=usage: 按 usage 分桶训练；输出 out_dir/usage_{name}/sentiment_axis_L{lid}.npy
# - mode=main : 全量合并为一个桶训练；输出 out_dir/main/sentiment_axis_L{lid}.npy
# - 同时导出训练指标汇总（per-usage 或 main）

import os, re, json, argparse, gc
import numpy as np
import torch, librosa, pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import AutoProcessor

# Optional model classes: instead of failing at import time, record which
# classes this transformers installation exposes so load_qwen2() can choose.
try:
    from transformers import Qwen2AudioForConditionalGeneration
    HAS_QWEN2 = True
except Exception:
    HAS_QWEN2 = False
try:
    from transformers import AutoModel
    HAS_AUTOMODEL = True
except Exception:
    HAS_AUTOMODEL = False

# Disable autograd globally; the code in this file only runs forward passes.
torch.set_grad_enabled(False)

# ---------------- utils ----------------
def unit(x):
    """Return ``x`` as a float32 array divided by its L2 norm.

    A tiny epsilon (1e-12) is added to the norm, so an all-zero input
    yields zeros rather than a division by zero.
    """
    vec = np.asarray(x, dtype=np.float32)
    length = float(np.linalg.norm(vec))
    scale = length + 1e-12
    return vec / scale

def safe_auc(y, s):
    """Return ``roc_auc_score(y, s)``, or NaN if the computation raises.

    Any exception from the metric is deliberately turned into NaN so that a
    single failing bucket/layer does not abort the metrics summary.
    """
    try:
        auc = roc_auc_score(y, s)
    except Exception:
        auc = float("nan")
    return auc

def sanitize(name: str):
    """Turn ``name`` into a lowercase, filesystem-friendly token.

    A falsy ``name`` (None or "") is replaced by "unknown". Every run of
    characters outside ``[A-Za-z0-9._-]`` collapses to a single underscore,
    then leading/trailing underscores are stripped and the result is
    lowercased. The result can be an empty string (e.g. for "!!!").
    """
    raw = name if name else "unknown"
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", raw)
    trimmed = collapsed.strip("_")
    return trimmed.lower()

def _detect_audio_token(processor):
    """Pick the audio placeholder token string to put in the text prompt.

    Resolution order:
      1. ``processor.tokenizer.audio_token`` when present and truthy.
      2. The first candidate spelling that the tokenizer maps to an id which
         is neither None nor the unknown-token id.
      3. The literal "<|AUDIO|>" as a last resort.

    Any exception raised while probing a candidate is ignored and the next
    candidate is tried.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    declared = getattr(tokenizer, "audio_token", None)
    if declared:
        return declared

    fallback = "<|AUDIO|>"
    for candidate in (fallback, "<|audio|>", "<|Audio|>"):
        try:
            token_id = processor.tokenizer.convert_tokens_to_ids(candidate)
            known = token_id is not None and token_id != processor.tokenizer.unk_token_id
        except Exception:
            # Best-effort probing: a missing tokenizer/attribute just means
            # this candidate cannot be confirmed.
            continue
        if known:
            return candidate
    return fallback

def safe_get_hidden_states(outputs):
    """Return the first non-None hidden-state container found on ``outputs``.

    Attributes are checked in this order: ``audio_hidden_states``,
    ``encoder_hidden_states``, ``hidden_states``, and finally
    ``language_model_output.hidden_states``.

    Raises:
        RuntimeError: if none of those attributes holds a non-None value.
    """
    for attr in ("audio_hidden_states", "encoder_hidden_states", "hidden_states"):
        found = getattr(outputs, attr, None)
        if found is not None:
            return found

    nested = getattr(outputs, "language_model_output", None)
    nested_states = None if nested is None else getattr(nested, "hidden_states", None)
    if nested_states is not None:
        return nested_states

    raise RuntimeError("hidden_states not found")

def read_waves(paths, sr):
    """Load every path as a mono waveform resampled to ``sr``.

    Loading is best-effort: a file that cannot be read does not abort the
    batch. It is replaced by a 1-sample float32 zero waveform so the output
    stays aligned with ``paths``, and a ``RuntimeWarning`` naming the file is
    emitted so the substitution is visible instead of silent.

    Args:
        paths: iterable of audio file paths.
        sr: target sampling rate passed to ``librosa.load``.

    Returns:
        list of 1-D waveforms, one per input path, in input order.
    """
    import warnings

    waves = []
    for p in paths:
        try:
            y, _ = librosa.load(p, sr=sr, mono=True)
        except Exception as exc:
            # Audio backends raise heterogeneous exception types, so the broad
            # catch is kept on purpose; the failure is reported, not hidden.
            warnings.warn(
                f"Could not load audio {p!r} ({exc!r}); using a 1-sample silent placeholder.",
                RuntimeWarning,
                stacklevel=2,
            )
            y = np.zeros(1, dtype=np.float32)
        waves.append(y)
    return waves

# -------------- IO --------------
# Record fields checked, in this priority order, for an explicit usage name;
# load_items() takes the first one that holds a non-empty string.
USAGE_KEYS = ["usage", "usage_cat", "bucket", "group"]

def infer_usage_from_name(path_or_id: str):
    """Extract a usage name embedded in a path or utterance id.

    Two patterns are tried in order, and the first hit wins:
      1. ``usage=<name>``, e.g. "F_usage=prosody_sent=positive" gives
         "prosody_sent" (underscores are part of the captured name).
      2. ``usage<name>`` or ``usage_<name>``, e.g. "usage_news-01.wav"
         gives "news".

    Returns "unknown" when neither pattern matches or the input is falsy.
    """
    text = path_or_id or ""
    for pattern in (r"usage=([A-Za-z0-9_]+)", r"usage_?([A-Za-z0-9_]+)"):
        hit = re.search(pattern, text)
        if hit:
            return hit.group(1)
    return "unknown"

def load_items(files, keep_usages=None, min_per_class=1):
    """Read training records from .jsonl / .json files and filter them.

    Input formats:
      - ``*.jsonl``: one JSON object per line; blank or malformed lines are
        skipped.
      - anything else: parsed as one JSON document that is either a list of
        records or a dict whose "items" key holds that list.
    Files that do not exist are skipped.

    Recognised record fields:
      - audio path: ``wav_path`` / ``path`` / ``audio`` (first truthy one)
      - label: ``label``, else ``sentiment``, else ``sent``; either the
        strings "positive"/"negative" (case-insensitive) or the values 0/1
      - usage: first non-empty string among ``USAGE_KEYS``; otherwise parsed
        from the audio path, and, if that yields "unknown", from ``utt_id``

    Filtering:
      - entries that are not JSON objects are ignored
      - records whose audio path is missing, not a string, or not on disk are
        dropped and counted in ``missing``
      - records with a non-binary label are dropped
      - when ``keep_usages`` is truthy, only usages in it are kept (usage
        names are compared after ``sanitize``)
      - usages with fewer than ``min_per_class`` positives or negatives are
        dropped entirely

    Returns:
        (records, missing): ``records`` is a list of dicts with keys
        ``utt_id``, ``wav_path``, ``label`` (0/1) and ``usage``, grouped by
        usage in first-seen order; ``missing`` is the number of records
        dropped for lacking a usable audio file.
    """
    items = []
    for fp in files:
        if not os.path.exists(fp):
            continue
        if fp.endswith(".jsonl"):
            with open(fp, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        items.append(json.loads(line))
                    except ValueError:
                        # json.JSONDecodeError is a ValueError; a corrupt line
                        # is skipped rather than aborting the whole load.
                        continue
        else:
            # Context manager so the handle is closed even if parsing fails.
            with open(fp, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, dict):
                data = data.get("items", [])
            if isinstance(data, list):
                items.extend(data)

    recs, miss = [], 0
    for d in items:
        if not isinstance(d, dict):
            # Bare strings/numbers in the list are not records.
            continue

        p = d.get("wav_path") or d.get("path") or d.get("audio") or ""
        if not isinstance(p, str) or not p or not os.path.exists(p):
            miss += 1
            continue

        lab = d.get("label", d.get("sentiment", d.get("sent")))
        if isinstance(lab, str):
            lab = lab.strip().lower()
            if lab not in {"positive", "negative"}:
                continue
            y = 1 if lab == "positive" else 0
        elif lab in (0, 1):
            y = int(lab)
        else:
            continue

        # Usage: explicit field first, then the path, then the utterance id.
        u = None
        for k in USAGE_KEYS:
            if k in d and isinstance(d[k], str) and d[k].strip():
                u = d[k].strip()
                break
        if not u:
            u = infer_usage_from_name(p)
            # infer_usage_from_name() reports a miss as the string "unknown"
            # (never a falsy value), so the utt_id fallback must test for that
            # sentinel explicitly.
            if u == "unknown":
                utt = d.get("utt_id")
                u = infer_usage_from_name(str(utt) if utt else "")
        u = sanitize(u)

        if keep_usages and u not in keep_usages:
            continue

        recs.append({
            "utt_id": d.get("utt_id") or os.path.basename(p),
            "wav_path": p,
            "label": y,
            "usage": u,
        })

    # Drop usages that lack enough samples of either class.
    by_u = {}
    for r in recs:
        by_u.setdefault(r["usage"], []).append(r)
    filt = []
    for lst in by_u.values():
        npos = sum(r["label"] for r in lst)
        nneg = len(lst) - npos
        if npos >= min_per_class and nneg >= min_per_class:
            filt.extend(lst)
    return filt, miss

def balance_idx(y, method="downsample", seed=42):
    """Return shuffled indices into ``y`` that balance the two classes.

    Args:
        y: 1-D sequence of 0/1 labels (array or plain list).
        method: "none" returns ``arange(len(y))`` untouched; "downsample"
            draws ``min(n_pos, n_neg)`` indices per class without
            replacement; any other value is treated as "upsample", drawing
            ``max(n_pos, n_neg)`` per class and sampling the smaller class
            with replacement.
        seed: seed for the private ``RandomState`` used for sampling/shuffle.

    Returns:
        1-D integer index array.

    Raises:
        ValueError: if a balancing method is requested and one class is
            absent.
    """
    # Coerce first: with a plain list, ``y == 1`` is the scalar False and both
    # classes would wrongly look empty.
    y = np.asarray(y)
    if method == "none":
        return np.arange(len(y))

    rng = np.random.RandomState(seed)
    pos = np.where(y == 1)[0]
    neg = np.where(y == 0)[0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("Only one class present; cannot balance.")

    if method == "downsample":
        K = min(len(pos), len(neg))
        sel_pos = rng.choice(pos, size=K, replace=False)
        sel_neg = rng.choice(neg, size=K, replace=False)
    else:  # upsample
        K = max(len(pos), len(neg))
        sel_pos = rng.choice(pos, size=K, replace=(len(pos) < K))
        sel_neg = rng.choice(neg, size=K, replace=(len(neg) < K))

    sel = np.concatenate([sel_pos, sel_neg])
    rng.shuffle(sel)
    return sel

# -------------- model --------------
def load_qwen2(model_id, device="cuda"):
    """Load the processor and an eval-mode model for ``model_id`` on ``device``.

    The dedicated ``Qwen2AudioForConditionalGeneration`` class is tried first
    when it could be imported; if that class is unavailable or its load
    raises, ``AutoModel`` is used instead (an exception from the AutoModel
    load propagates unchanged). Both loads pass ``trust_remote_code=True``.

    Returns:
        (processor, model)

    Raises:
        RuntimeError: if no model could be loaded; the message carries the
            error from the Qwen2-Audio attempt, if any.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    model = None
    primary_error = None
    if HAS_QWEN2:
        try:
            loaded = Qwen2AudioForConditionalGeneration.from_pretrained(
                model_id, trust_remote_code=True
            )
            model = loaded.eval().to(device)
        except Exception as exc:
            # Remember the failure; the generic loader gets a chance below.
            primary_error = exc

    if model is None and HAS_AUTOMODEL:
        generic = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        model = generic.eval().to(device)

    if model is None:
        raise RuntimeError(f"Failed to load {model_id}: {primary_error!r}")
    return processor, model

def encode_wavs(wav_paths, processor, model, batch_size=8, sr=16000, device="cuda"):
    """Encode every audio file once and mean-pool each layer's hidden states.

    A single-file probe pass determines the layer count and hidden size. If
    the first two hidden-state entries share the same last dimension, entry 0
    is skipped (presumably the embedding output — TODO confirm for the model
    in use). Files are then processed in batches. For each kept layer the
    sequence dimension is averaged, using the input attention mask when its
    (batch, time) shape matches the hidden state and a plain mean otherwise.

    Args:
        wav_paths: non-empty list of audio file paths.
        processor: callable producing model inputs from ``audio=`` / ``text=``.
        model: model called with ``output_hidden_states=True``.
        batch_size: number of files per forward pass.
        sr: sampling rate used for loading and passed to the processor.
        device: device the inputs are moved to; "cuda" also triggers a cache
            flush after every batch.

    Returns:
        (embs, L, H): ``embs`` is a list of ``L`` float32 arrays of shape
        ``(len(wav_paths), H)``, one per kept layer.

    Raises:
        ValueError: if ``wav_paths`` is empty.
    """
    N = len(wav_paths)
    if N == 0:
        # Fail with a clear message instead of an IndexError on the probe.
        raise ValueError("encode_wavs: wav_paths is empty; nothing to encode.")
    audio_tok = _detect_audio_token(processor)

    # Probe: one forward pass to discover layer count and hidden size. The
    # file goes through read_waves() so an unreadable first file is handled
    # the same way as in the batch loop below.
    y0 = read_waves(wav_paths[:1], sr)[0]
    probe = processor(audio=[y0], text=[audio_tok], sampling_rate=sr,
                      return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        out0 = model(**probe, output_hidden_states=True)
    hs0 = safe_get_hidden_states(out0)
    if len(hs0) >= 2 and hs0[0].shape[-1] == hs0[1].shape[-1]:
        start_idx = 1; L = len(hs0) - 1
    else:
        start_idx = 0; L = len(hs0)
    H = hs0[start_idx].shape[-1]
    # Release the probe tensors before the main loop allocates its batches.
    del probe, out0, hs0, y0
    embs = [np.empty((N, H), dtype=np.float32) for _ in range(L)]

    for i in tqdm(range(0, N, batch_size), desc="🎧 Encode"):
        chunk = wav_paths[i:i+batch_size]
        waves = read_waves(chunk, sr)
        inputs = processor(audio=waves, text=[audio_tok]*len(waves),
                           sampling_rate=sr, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            out = model(**inputs, output_hidden_states=True)
        hs = safe_get_hidden_states(out)

        # Locate an attention mask among the processor outputs.
        attn = None
        for k in ("attention_mask", "audio_attention_mask", "input_attention_mask"):
            if k in inputs and inputs[k] is not None:
                attn = inputs[k]; break
        mask = None
        if attn is not None:
            # Squeeze a singleton middle/last axis so the mask is [B, T].
            if attn.dim() == 3:
                if attn.size(1) == 1: attn = attn.squeeze(1)
                elif attn.size(2) == 1: attn = attn.squeeze(2)
            if attn.dim() == 2:
                mask = attn.unsqueeze(-1).float()

        # Pool over the sequence dimension, layer by layer.
        for l in range(L):
            h = hs[start_idx + l]  # expected [B, T, H]
            if h.dim() == 3:
                if mask is not None and mask.shape[:2] == h.shape[:2]:
                    pooled = (h*mask).sum(1)/mask.sum(1).clamp(min=1e-6)
                else:
                    # Mask missing or misaligned with this tensor: plain mean.
                    pooled = h.mean(1)
            else:
                pooled = h.reshape(h.size(0), -1)
            # Cast to float32 on the torch side: NumPy has no bfloat16, so
            # .numpy() would fail for a bf16 model.
            embs[l][i:i+len(chunk)] = pooled.detach().float().cpu().numpy()
        del inputs, out, hs, waves
        if device == "cuda":
            torch.cuda.empty_cache()
    return embs, L, H

# --------- train a single bucket (usage/main) ----------
def train_one_bucket(bucket_name, sel_idx, y_all, embs_by_layer, out_dir, verbose=1):
    """Fit one intercept-free logistic regression per layer and save its axis.

    For the samples selected by ``sel_idx``, each layer's classifier weight
    vector is L2-normalised, flipped if needed so positives score higher on
    average, and written to
    ``out_dir/<bucket_name>/sentiment_axis_L<layer>.npy`` (layers numbered
    from 1). Metrics are computed on the same samples used for fitting.

    Returns:
        list of per-layer dicts with keys bucket, layer, N, pos, neg, acc,
        f1, auroc; an empty list when the selection contains a single class.
    """
    metrics = []
    labels = y_all[sel_idx]
    n_total = len(labels)
    label_sum = labels.sum()
    if label_sum == 0 or label_sum == n_total:
        if verbose:
            print(f"Skip {bucket_name}: only one class.")
        return metrics

    target_dir = os.path.join(out_dir, f"{bucket_name}")
    os.makedirs(target_dir, exist_ok=True)

    n_pos = int(label_sum)
    n_neg = n_total - n_pos
    is_pos = labels == 1
    is_neg = labels == 0

    for layer_id, layer_embs in enumerate(embs_by_layer, start=1):
        feats = layer_embs[sel_idx]
        probe = LogisticRegression(max_iter=2000, fit_intercept=False)
        probe = probe.fit(feats, labels)
        direction = probe.coef_[0].astype(np.float32)
        direction = direction / (np.linalg.norm(direction) + 1e-12)

        # Orient the axis so that positive samples get the larger mean score.
        mean_pos = float((feats[is_pos] @ direction).mean())
        mean_neg = float((feats[is_neg] @ direction).mean())
        if mean_pos < mean_neg:
            direction = -direction

        axis_path = os.path.join(target_dir, f"sentiment_axis_L{layer_id}.npy")
        np.save(axis_path, direction)

        scores = feats @ direction
        hard = (scores > 0).astype(int)
        entry = {
            "bucket": bucket_name,
            "layer": layer_id,
            "N": n_total,
            "pos": n_pos,
            "neg": n_neg,
            "acc": accuracy_score(labels, hard),
            "f1": f1_score(labels, hard) if len(set(labels)) > 1 else float("nan"),
            "auroc": safe_auc(labels, scores),
        }
        metrics.append(entry)

    if metrics and verbose:
        best = max(metrics, key=lambda r: r["acc"])
        print(f"✅ {bucket_name} | best L{best['layer']} acc={best['acc']:.4f}")
    return metrics

# -------------- main --------------
def main(args):
    """Run the pipeline: load records, encode audio once, train axes, summarise.

    Steps:
      A) Load and filter records from ``args.train_files``.
      B) Load the model, encode all audio in a single pass, then release the
         model since training only needs the cached embeddings.
      C) Train per-layer axes, either on one merged "main" bucket or on one
         bucket per usage, balancing classes per ``args.balance``.
      D) Write a CSV metrics summary under ``args.out_dir``.

    NOTE(review): the ``--keep_usages`` help text says it only applies in
    usage mode, but the filter is handed to ``load_items`` unconditionally,
    so it also restricts main mode — confirm which is intended.

    Raises:
        SystemExit: if no records survive filtering, or if main mode finds
            only one class.
    """
    os.makedirs(args.out_dir, exist_ok=True)

    # A) Load data.
    keep = None
    if args.keep_usages:
        # Skip empty tokens (e.g. a trailing comma): sanitize("") would turn
        # them into "unknown" and silently whitelist that bucket.
        keep = {sanitize(tok) for tok in args.keep_usages.split(",") if tok.strip()} or None
    recs, miss = load_items(args.train_files, keep_usages=keep, min_per_class=args.min_per_class)
    if args.verbose:
        print(f"Loaded {len(recs)} items (missing {miss}).")
    if not recs:
        raise SystemExit("No valid records after filtering.")

    wavs = [r["wav_path"] for r in recs]
    y_all = np.array([r["label"] for r in recs], dtype=int)
    usage_all = [r["usage"] for r in recs]

    # B) Encode everything once.
    processor, model = load_qwen2(args.model_id, args.device)
    embs_by_layer, L, H = encode_wavs(wavs, processor, model, batch_size=args.batch_size,
                                      sr=args.sample_rate, device=args.device)
    if args.verbose:
        print(f"Derived structure: layers={L}, hidden_size={H}")

    # The model is not needed for training; free it before fitting.
    del model, processor
    gc.collect()
    if args.device == "cuda":
        torch.cuda.empty_cache()

    # C) Train.
    rows = []

    if args.mode == "main":
        # Merge all samples into a single bucket.
        idx_all = np.arange(len(y_all), dtype=int)
        try:
            sel = balance_idx(y_all, method=args.balance, seed=args.seed)
        except ValueError:
            raise SystemExit("Only one class present globally; cannot train main axis.") from None
        sel_idx = idx_all[sel]
        rows += train_one_bucket("main", sel_idx, y_all, embs_by_layer, args.out_dir, verbose=args.verbose)

    else:  # mode == "usage"
        # Group sample indices by usage in one pass.
        idx_by_usage = {}
        for i, u in enumerate(usage_all):
            idx_by_usage.setdefault(u, []).append(i)

        for u in sorted(idx_by_usage):
            idx = np.array(idx_by_usage[u], dtype=int)
            y = y_all[idx]
            if y.sum() == 0 or y.sum() == len(y):
                if args.verbose: print(f"Skip usage={u}: only one class.")
                continue
            try:
                sel = balance_idx(y, method=args.balance, seed=args.seed)
            except ValueError:
                if args.verbose: print(f"Skip usage={u}: cannot balance.")
                continue
            # ``sel`` indexes into this usage's subset; map back to global rows.
            sel_idx = idx[sel]

            rows += train_one_bucket(f"usage_{u}", sel_idx, y_all, embs_by_layer, args.out_dir, verbose=args.verbose)

    # D) Summary.
    if rows:
        df = pd.DataFrame(rows).sort_values(["bucket", "layer"])
        summary_name = "train_summary_main.csv" if args.mode == "main" else "train_summary_per_usage.csv"
        df.to_csv(os.path.join(args.out_dir, summary_name), index=False)
        if args.verbose:
            print(f"Saved summary to {os.path.join(args.out_dir, summary_name)}")
            if args.mode == "main":
                print(f"Axes saved under: {args.out_dir}/main/sentiment_axis_L*.npy")
            else:
                print(f"Axes saved under: {args.out_dir}/usage_*/sentiment_axis_L*.npy")
    else:
        print("No results.")

    # Final cleanup of anything still unreachable.
    gc.collect()

if __name__ == "__main__":
    # Command-line entry point: declare the options as a table, register
    # them on the parser in order, then hand the parsed namespace to main().
    parser = argparse.ArgumentParser()
    option_table = [
        ("--mode", dict(choices=["usage", "main"], default="usage",
                        help="usage: 按 usage 分桶训练；main: 全量合并为一个桶训练")),
        ("--model_id", dict(required=True, help="如 Qwen/Qwen2-Audio-7B-Instruct")),
        ("--train_files", dict(nargs="+", required=True,
                               help="一个或多个 .jsonl/.json（含 wav_path + label/sentiment + usage）")),
        ("--out_dir", dict(required=True, help="输出根目录")),
        ("--device", dict(default="cuda", choices=["cuda", "cpu"])),
        ("--sample_rate", dict(type=int, default=16000)),
        ("--batch_size", dict(type=int, default=8)),
        ("--balance", dict(choices=["none", "downsample", "upsample"], default="downsample",
                           help="类别均衡方式：usage 模式在每个 usage 内均衡；main 模式在全量上均衡")),
        ("--min_per_class", dict(type=int, default=1,
                                 help="每个 usage 至少的正/负样本数，低于则跳过")),
        ("--keep_usages", dict(default=None,
                               help="仅在 usage 模式下生效：只训练这些 usage（逗号分隔）")),
        ("--seed", dict(type=int, default=42)),
        ("--verbose", dict(type=int, default=1)),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args)
