#!/usr/bin/env python3
# test_qwen_audio_axis.py  (auto 版)
# 模式：
#   main  : 单主轴目录
#   usage : 单 usage 或遍历 usage_root/usage_*
#   mix   : 任意多轴融合（手动指定 --axis name=dir 与 --w name=weight）
#   auto  : 自动发现 usage 轴 + 网格搜权重 + 批量跑（solo/pairwise/triples/all）
#
# 轴目录均应包含：sentiment_axis_L{lid}.npy

import os, re, json, argparse, itertools
import numpy as np
import torch, librosa, pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import AutoProcessor

# Optional model classes. Each import is guarded because it may not be
# importable in every installed transformers version; load_qwen2 reads these
# flags to decide which loaders to try (the Qwen2-Audio class first, then
# AutoModel as a fallback).
try:
    from transformers import Qwen2AudioForConditionalGeneration
    HAS_QWEN2_CLASS = True
except Exception:
    HAS_QWEN2_CLASS = False
try:
    from transformers import AutoModel
    HAS_AUTOMODEL = True
except Exception:
    HAS_AUTOMODEL = False

# Inference-only script: turn autograd off globally (the forward passes are
# additionally wrapped in torch.no_grad()).
torch.set_grad_enabled(False)

# ---------- utils ----------
def unit(x: np.ndarray) -> np.ndarray:
    """Return ``x`` scaled to unit L2 norm.

    The 1e-12 epsilon keeps the division finite for an all-zero input, which
    therefore comes back as zeros.
    """
    norm = np.linalg.norm(x)
    return x / (norm + 1e-12)

def safe_auc(y, s):
    """Return the ROC AUC of scores ``s`` against labels ``y``, or NaN.

    NaN is returned whenever ``roc_auc_score`` raises. For instance, sklearn
    rejects label arrays that contain only one class.
    """
    try:
        auc = roc_auc_score(y, s)
    except Exception:
        auc = float("nan")
    return auc

def sanitize(name: str) -> str:
    """Turn ``name`` into a lowercase, filesystem-friendly slug.

    ``None`` or an empty string falls back to "run". Surrounding whitespace is
    stripped, and each run of characters outside ``[A-Za-z0-9._-]`` becomes a
    single underscore.
    """
    text = name if name else "run"
    text = text.strip()
    slug = re.sub(r"[^A-Za-z0-9._-]+", "_", text)
    return slug.lower()

def basename_to_axis_name(path: str) -> str:
    """Derive a short axis name from an axis directory path.

    The name is the sanitized last path component with a leading "usage_"
    prefix removed (e.g. ".../usage_tone" -> "tone"). It falls back to "axis"
    when nothing is left.

    Trailing path separators are ignored. Otherwise ".../usage_tone/" would
    have an empty basename and sanitize() would turn it into the generic
    name "run".
    """
    separators = os.sep + (os.altsep or "")
    b = sanitize(os.path.basename(path.rstrip(separators)))
    if b.startswith("usage_"):
        b = b[len("usage_"):]
    return b or "axis"

# ---------- IO ----------
def load_audio_list(jpath, wav_key="wav_path", label_key="label"):
    """Read a JSON list manifest and return the usable (id, wav, label) entries.

    Each item is expected to be a dict holding the audio path under
    ``wav_key`` and a label under ``label_key``.

    Items are skipped when:
      * the path or the label is missing;
      * a string label is anything other than "positive"/"negative"
        (case-insensitive);
      * the audio file does not exist.
    String labels map positive -> 1 and negative -> 0. Non-string labels are
    passed through ``int()``.

    Returns:
        (ids, wavs, labels): ``ids`` uses the item's "utt_id" or falls back to
        the file's basename; ``labels`` is an int NumPy array aligned with
        ``wavs``.
    """
    # Context manager so the manifest handle is closed deterministically
    # (json.load(open(...)) leaked it until garbage collection).
    with open(jpath, "r", encoding="utf-8") as f:
        data = json.load(f)

    ids, wavs, labels = [], [], []
    for d in data:
        p = d.get(wav_key)
        lab = d.get(label_key)
        if not p or lab is None:
            continue
        if isinstance(lab, str):
            lab = lab.lower()
            if lab not in {"positive", "negative"}:
                continue
            y = 1 if lab == "positive" else 0
        else:
            y = int(lab)
        if os.path.exists(p):
            wavs.append(p)
            labels.append(y)
            ids.append(d.get("utt_id", os.path.basename(p)))
    return ids, wavs, np.array(labels, dtype=int)

def _detect_audio_token(processor):
    tok = getattr(getattr(processor, "tokenizer", None), "audio_token", None)
    if tok:
        return tok
    for c in ["<|AUDIO|>", "<|audio|>", "<|Audio|>"]:
        try:
            tid = processor.tokenizer.convert_tokens_to_ids(c)
            if tid is not None and tid != processor.tokenizer.unk_token_id:
                return c
        except Exception:
            pass
    return "<|AUDIO|>"

def safe_get_hidden_states(outputs):
    """Return the first non-None hidden-state collection found on a model output.

    Attributes are checked in this order: ``audio_hidden_states``,
    ``encoder_hidden_states``, ``hidden_states``, and finally
    ``language_model_output.hidden_states``.

    Raises:
        RuntimeError: if none of them is present and non-None.
    """
    direct_keys = ("audio_hidden_states", "encoder_hidden_states", "hidden_states")
    candidates = (getattr(outputs, key, None) for key in direct_keys)
    found = next((hs for hs in candidates if hs is not None), None)
    if found is not None:
        return found

    lm_output = getattr(outputs, "language_model_output", None)
    nested = getattr(lm_output, "hidden_states", None)
    if nested is None:
        raise RuntimeError("hidden_states not found")
    return nested

def read_waves(paths, sr):
    """Load each file in ``paths`` as a mono waveform resampled to ``sr`` Hz.

    Best-effort: a file that fails to load is replaced by a 1-sample silent
    float32 array, so the returned list stays aligned with ``paths``. The
    placeholder is still encoded and scored like a real clip, so every
    failure is reported on stdout instead of being swallowed silently.
    """
    waves = []
    for p in paths:
        try:
            y, _ = librosa.load(p, sr=sr, mono=True)
        except Exception as e:  # deliberately broad: any decode/IO failure -> placeholder
            print(f"⚠️ Failed to load audio {p!r} ({e!r}); substituting silence.")
            y = np.zeros(1, dtype=np.float32)
        waves.append(y)
    return waves

def load_qwen2(model_id, device="cuda"):
    """Load the processor and an eval-mode model for ``model_id`` on ``device``.

    Tries ``Qwen2AudioForConditionalGeneration`` first (if it was importable),
    then ``AutoModel``. Both loaders use ``trust_remote_code=True``.

    Raises:
        RuntimeError: if no loader succeeds. The message lists the failure of
            every loader that was tried and is chained to the last exception,
            so the first loader's error is no longer hidden when the fallback
            also fails.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    failures = []
    last_err = None

    if HAS_QWEN2_CLASS:
        try:
            model = Qwen2AudioForConditionalGeneration.from_pretrained(
                model_id, trust_remote_code=True
            ).eval().to(device)
            return processor, model
        except Exception as e:  # fall through to the generic loader
            failures.append(f"Qwen2AudioForConditionalGeneration: {e!r}")
            last_err = e

    if HAS_AUTOMODEL:
        try:
            model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().to(device)
            return processor, model
        except Exception as e:
            failures.append(f"AutoModel: {e!r}")
            last_err = e

    detail = "; ".join(failures) or "no model class available in transformers"
    raise RuntimeError(f"Failed to load {model_id}: {detail}") from last_err

def _find_attention_mask(inputs):
    """Return the processor's attention mask as a float tensor with a trailing unit dim, or None.

    Takes the first non-None entry among a few known key names. A 3-D mask
    with a singleton middle or last dim is squeezed to 2-D first. A mask that
    is still not 2-D is ignored (None), so pooling falls back to a plain mean.
    The result is presumably (batch, seq, 1).
    """
    attn = None
    for key in ("attention_mask", "audio_attention_mask", "input_attention_mask"):
        if key in inputs and inputs[key] is not None:
            attn = inputs[key]
            break
    if attn is None:
        return None
    if attn.dim() == 3:
        if attn.size(1) == 1:
            attn = attn.squeeze(1)
        elif attn.size(2) == 1:
            attn = attn.squeeze(2)
    if attn.dim() != 2:
        return None
    return attn.unsqueeze(-1).float()


def _pool_hidden(h, mask):
    """Pool one layer's hidden states into a float32 NumPy array with one row per clip.

    3-D tensors are averaged over dim 1. The average is masked when ``mask``
    matches ``h`` on the first two dims, and a plain mean otherwise. Tensors
    of any other rank are flattened per row.
    """
    if h.dim() == 3:
        if mask is not None and mask.shape[:2] == h.shape[:2]:
            pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-6)
        else:
            pooled = h.mean(1)
    else:
        pooled = h.reshape(h.size(0), -1)
    # Upcast before .numpy(): NumPy has no bfloat16, so a bf16 model would fail here.
    return pooled.detach().float().cpu().numpy().astype(np.float32)


def batch_encode_dynamic(wav_paths, processor, model, batch_size=8, target_sr=16000, device="cuda"):
    """Encode every clip and pool each hidden-state layer to one vector per clip.

    A one-clip probe forward pass discovers how many hidden-state entries the
    model returns and their width. If the first two entries share a width,
    entry 0 is skipped (presumably the embedding-layer output).

    Returns:
        (embs, L, H): ``embs`` is a list of ``L`` float32 arrays of shape
        ``(len(wav_paths), H)``, rows ordered like ``wav_paths``.

    Raises:
        ValueError: if ``wav_paths`` is empty (there is nothing to probe).
    """
    N = len(wav_paths)
    if N == 0:
        raise ValueError("batch_encode_dynamic: wav_paths is empty")
    audio_tok = _detect_audio_token(processor)

    # Probe: one forward pass to learn the layer count and hidden size.
    # read_waves applies the same unreadable-file fallback as the main loop.
    probe = processor(audio=read_waves(wav_paths[:1], target_sr), text=[audio_tok],
                      sampling_rate=target_sr, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        out0 = model(**probe, output_hidden_states=True)
    hs0 = safe_get_hidden_states(out0)
    if len(hs0) >= 2 and hs0[0].shape[-1] == hs0[1].shape[-1]:
        start_idx, L = 1, len(hs0) - 1
    else:
        start_idx, L = 0, len(hs0)
    H = hs0[start_idx].shape[-1]
    del probe, out0, hs0

    embs = [np.empty((N, H), dtype=np.float32) for _ in range(L)]
    # startswith also covers explicit device strings such as "cuda:1".
    on_cuda = str(device).startswith("cuda")

    for i in tqdm(range(0, N, batch_size), desc="🎧 Encode"):
        chunk = wav_paths[i:i + batch_size]
        waves = read_waves(chunk, target_sr)
        inputs = processor(audio=waves, text=[audio_tok] * len(waves),
                           sampling_rate=target_sr, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hs = safe_get_hidden_states(outputs)

        mask = _find_attention_mask(inputs)
        for l in range(L):
            embs[l][i:i + len(chunk)] = _pool_hidden(hs[start_idx + l], mask)

        del inputs, outputs, hs, waves
        if on_cuda:
            torch.cuda.empty_cache()
    return embs, L, H

# ---------- axis loading ----------
def load_axis(axis_dir, lid, H):
    """Load layer ``lid``'s axis from ``axis_dir/sentiment_axis_L{lid}.npy``.

    Returns a 1-D float32 unit vector of length ``H``. Returns None when the
    file is missing or holds a number of elements other than ``H``.

    Row or column vectors of shape ``(1, H)`` or ``(H, 1)`` are flattened.
    A 2-D axis would otherwise pass a last-dim check and then fail in
    ``X @ u``, or be rejected outright.
    """
    p = os.path.join(axis_dir, f"sentiment_axis_L{lid}.npy")
    if not os.path.exists(p):
        return None
    v = np.asarray(np.load(p), dtype=np.float32)
    if v.size != H:
        return None
    return unit(v.reshape(-1))

# ---------- eval ----------
def eval_one_run(run_dir, X_layers, y_true, axis_loader):
    """Score every layer by projecting its embeddings onto that layer's axis.

    For each 1-based layer id ``lid`` in ``1..len(X_layers)``,
    ``axis_loader(lid)`` returns the axis ``u``, or None to skip the layer.
    Scores are ``X @ u`` and the prediction is ``score > 0``.

    When accuracy is below 0.5, the axis sign is flipped and the metrics are
    recomputed. Note that this orientation is chosen using the evaluation
    labels themselves. F1 is NaN when ``y_true`` contains a single class.

    Writes ``proj_L{lid}.csv`` for each evaluated layer and ``summary.csv``
    into ``run_dir`` (created if needed), and prints per-layer metrics.

    Returns:
        dict with best_layer / best_acc / best_f1 / best_auroc for the layer
        with the highest accuracy, or None if no layer had an axis.
    """
    os.makedirs(run_dir, exist_ok=True)
    has_both_classes = len(set(y_true)) > 1

    def metrics(scores, preds):
        acc = accuracy_score(y_true, preds)
        f1 = f1_score(y_true, preds) if has_both_classes else float("nan")
        return acc, f1, safe_auc(y_true, scores)

    rows = []
    for lid, X in enumerate(X_layers, start=1):
        u = axis_loader(lid)
        if u is None:
            # No axis for this layer (e.g. a usage dir that lacks some layers).
            continue
        scores = X @ u
        preds = (scores > 0).astype(int)
        acc, f1, auc = metrics(scores, preds)
        if acc < 0.5:
            scores, preds = -scores, 1 - preds
            acc, f1, auc = metrics(scores, preds)

        projection = pd.DataFrame({
            "score": scores,
            "pred_label": preds,
            "true_label": y_true,
            "correct": (preds == y_true).astype(int),
        })
        projection.to_csv(os.path.join(run_dir, f"proj_L{lid}.csv"), index=False)

        rows.append({"layer": lid, "acc": acc, "f1": f1, "auroc": auc})
        print(f"  L{lid:02d} | Acc {acc:.4f} | F1 {f1:.4f} | AUROC {auc:.4f}")

    if not rows:
        print("No results for this run.")
        return None

    summary = pd.DataFrame(rows).sort_values("layer")
    summary_path = os.path.join(run_dir, "summary.csv")
    summary.to_csv(summary_path, index=False)
    best = summary.loc[summary["acc"].idxmax()]
    print(f"\n✅ Best L{int(best.layer)} | Acc {best.acc:.4f} | F1 {best.f1:.4f} | AUROC {best.auroc:.4f}")
    print(f"✅ Saved summary to {summary_path}")
    return {
        "best_layer": int(best.layer),
        "best_acc": float(best.acc),
        "best_f1": float(best.f1),
        "best_auroc": float(best.auroc),
    }

# ---------- helpers ----------
def parse_kv_list(kv_list):
    """Turn ["name=value", ...] into {name: value}.

    Each item is split on its first "=" and both sides are stripped. A falsy
    input (None or an empty list) yields {}. Later duplicates overwrite
    earlier ones.

    Raises:
        ValueError: for an item that contains no "=".
    """
    parsed = {}
    for item in kv_list or ():
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"Expect key=value, got: {item}")
        parsed[key.strip()] = value.strip()
    return parsed

def parse_float_list(s):
    """Parse a comma-separated string such as "1.0,0.5,2" into a list of floats.

    Blank tokens are ignored, so trailing or doubled commas ("1.0," or
    "1,,2") no longer raise. A falsy input yields [].

    Raises:
        ValueError: for a non-blank token that is not a valid float.
    """
    if not s:
        return []
    return [float(tok) for tok in s.split(",") if tok.strip()]

# -------------------- main --------------------
def _parse_layer_spec(spec, num_layers):
    """Parse an --only_layers spec such as "5-20" or "6,8,10" (the two forms can be mixed).

    Returns the sorted, de-duplicated 1-based layer ids that fall inside
    [1, num_layers]. Blank tokens are ignored.

    Raises:
        ValueError: for a token that is neither an int nor an "a-b" int range.
    """
    keep = set()
    for tok in spec.split(","):
        tok = tok.strip()
        if not tok:
            continue
        if "-" in tok:
            lo, hi = tok.split("-", 1)
            keep.update(range(int(lo), int(hi) + 1))
        else:
            keep.add(int(tok))
    return sorted(l for l in keep if 1 <= l <= num_layers)


def _fused_axis_loader(axis_dirs, weights, H):
    """Build ``lid -> unit(sum_k weights[k] * axis_k(lid))`` over ``axis_dirs``.

    Axes are summed in ``axis_dirs`` insertion order. An axis for which
    ``load_axis`` returns None at a given layer is skipped. The loader returns
    None when no axis exists for that layer. Every name in ``axis_dirs`` must
    be a key of ``weights``.
    """
    def loader(lid):
        acc = None
        for name, axis_dir in axis_dirs.items():
            u = load_axis(axis_dir, lid, H)
            if u is None:
                continue
            term = u * weights[name]
            acc = term if acc is None else acc + term
        return None if acc is None else unit(acc)
    return loader


def _discover_usage_axes(axis_roots, manual_axes):
    """Map axis name -> directory for every subdirectory of each --axis_root.

    Names come from ``basename_to_axis_name``. On a name clash, later roots
    override earlier ones, and manual ``--axis name=dir`` entries override
    both. Roots that are not directories are skipped with a warning.
    """
    discovered = {}
    for root in axis_roots or []:
        if not os.path.isdir(root):
            print(f"⚠️ --axis_root is not a directory, skipped: {root}")
            continue
        for entry in sorted(os.listdir(root)):
            full = os.path.join(root, entry)
            if os.path.isdir(full):
                discovered[basename_to_axis_name(full)] = full
    discovered.update(parse_kv_list(manual_axes))
    return discovered


def _plan_auto_runs(main_dir, discovered, w_main_grid, w_grids, want):
    """List auto-mode runs as ``(tag, {name: dir}, {name: weight})`` tuples.

    ``want`` selects the combination families:
      * "solo": each axis, including main, alone at weight 1.0;
      * "pairwise": main + one usage axis;
      * "triples": main + two usage axes;
      * "all": main + every usage axis.
    The weighted families take the Cartesian product of ``w_main_grid`` and
    the per-axis ``w_grids``. Runs with a duplicate tag are dropped, keeping
    the first.
    """
    runs = []
    if "solo" in want:
        runs.append(("solo_main", {"main": main_dir}, {"main": 1.0}))
        for name, d in discovered.items():
            runs.append((f"solo_{name}", {name: d}, {name: 1.0}))

    if "pairwise" in want:
        for name, d in discovered.items():
            for wm in w_main_grid:
                for wu in w_grids[name]:
                    runs.append((f"pair_{name}_wm{wm}_w{name}{wu}",
                                 {"main": main_dir, name: d},
                                 {"main": wm, name: wu}))

    if "triples" in want and len(discovered) >= 2:
        for names in itertools.combinations(sorted(discovered), 2):
            for wm in w_main_grid:
                for ws in itertools.product(*(w_grids[n] for n in names)):
                    tag = ("tri_" + "+".join(names) + f"_wm{wm}_"
                           + "_".join(f"w{n}{w}" for n, w in zip(names, ws)))
                    runs.append((tag,
                                 {"main": main_dir, **{n: discovered[n] for n in names}},
                                 {"main": wm, **dict(zip(names, ws))}))

    if "all" in want and discovered:
        usage_names = sorted(discovered)
        for wm in w_main_grid:
            for combo in itertools.product(*(w_grids[n] for n in usage_names)):
                tag = f"all_wm{wm}_" + "_".join(f"w{n}{w}" for n, w in zip(usage_names, combo))
                runs.append((tag,
                             {"main": main_dir, **{n: discovered[n] for n in usage_names}},
                             {"main": wm, **dict(zip(usage_names, combo))}))

    unique = {}
    for run in runs:
        unique.setdefault(run[0], run)
    return list(unique.values())


def main(args):
    """Encode the audio set once, then evaluate sentiment axes according to ``args.mode``.

    Modes:
      * main:  the axes in --axis_dir_main.
      * usage: one --usage_dir, or every ``usage_*`` subdirectory of
               --usage_root (plus usage_overview.csv).
      * mix:   a weighted sum of --axis name=dir axes (--w name=weight,
               default weight 1.0).
      * auto:  discover usage axes under --axis_root and grid-search fusion
               weights with main (writes auto_overview.csv).

    Mode-specific arguments are validated before the model is loaded, so a
    missing flag fails fast. --only_layers restricts which layers are
    evaluated, and results keep the model's original layer numbering.

    Raises:
        ValueError: missing or invalid mode arguments, or no usable audio.
    """
    # Validate up front: encoding is expensive, and explicit raises (unlike
    # asserts) still fire under `python -O`.
    mix_axes, mix_weights = {}, {}
    if args.mode == "main":
        if not args.axis_dir_main:
            raise ValueError("--axis_dir_main is required")
    elif args.mode == "usage":
        if not (args.usage_dir or args.usage_root):
            raise ValueError("--usage_root or --usage_dir is required")
    elif args.mode == "mix":
        mix_axes = parse_kv_list(args.axis)  # name -> dir
        mix_weights = {k: float(v) for k, v in parse_kv_list(args.w).items()}
        if not mix_axes:
            raise ValueError("At least one --axis name=dir is required for mode=mix")
    elif args.mode == "auto":
        if not args.axis_dir_main:
            raise ValueError("--axis_dir_main is required for mode=auto")
    else:
        raise ValueError(f"Unknown mode: {args.mode}")

    os.makedirs(args.out_dir, exist_ok=True)

    # Data & model: every layer's pooled embedding is computed once and reused by all runs.
    _, wavs, y_true = load_audio_list(args.audio_json, args.wav_key, args.label_key)
    if not wavs:
        raise ValueError("No audio found.")
    processor, model = load_qwen2(args.model_id, args.device)
    X_layers, L, H = batch_encode_dynamic(
        wavs, processor, model, batch_size=args.batch_size, target_sr=args.sr, device=args.device
    )
    print(f"🔍 Derived structure: layers={L}, hidden_size={H}")

    # --only_layers: X_layers stays full length and excluded layers get a None
    # axis, which eval_one_run skips. This keeps the real layer ids in
    # proj_L*.csv, summary.csv and best_layer instead of renumbering from 1.
    keep = None
    if args.only_layers:
        keep = set(_parse_layer_spec(args.only_layers, L)) or None
        if keep is None:
            print(f"⚠️ --only_layers matched no layer in 1..{L}; evaluating all layers.")

    def restrict(loader):
        """Wrap an axis loader so layers outside --only_layers yield None."""
        if keep is None:
            return loader
        return lambda lid: loader(lid) if lid in keep else None

    if args.mode == "main":
        run_dir = os.path.join(args.out_dir, "main")
        eval_one_run(run_dir, X_layers, y_true,
                     restrict(lambda lid: load_axis(args.axis_dir_main, lid, H)))
        return

    if args.mode == "usage":
        if args.usage_dir:
            # normpath drops a trailing separator. Otherwise the basename would
            # be empty and the run dir would fall back to the generic "run".
            name = sanitize(os.path.basename(os.path.normpath(args.usage_dir)))
            run_dir = os.path.join(args.out_dir, name)
            eval_one_run(run_dir, X_layers, y_true,
                         restrict(lambda lid: load_axis(args.usage_dir, lid, H)))
            return

        usage_dirs = sorted(d for d in os.listdir(args.usage_root)
                            if d.startswith("usage_")
                            and os.path.isdir(os.path.join(args.usage_root, d)))
        rows = []
        for d in usage_dirs:
            udir = os.path.join(args.usage_root, d)
            loader = restrict(lambda lid, _udir=udir: load_axis(_udir, lid, H))
            info = eval_one_run(os.path.join(args.out_dir, d), X_layers, y_true, loader)
            if info:
                rows.append({"usage": d, **info})
        if rows:
            overview_path = os.path.join(args.out_dir, "usage_overview.csv")
            pd.DataFrame(rows).sort_values("best_acc", ascending=False)\
              .to_csv(overview_path, index=False)
            print(f"✅ Wrote overview: {overview_path}")
        return

    if args.mode == "mix":
        for k in mix_axes:
            mix_weights.setdefault(k, 1.0)
        tag = "+".join([f"{sanitize(k)}({mix_weights[k]})" for k in sorted(mix_axes.keys())])
        run_dir = os.path.join(args.out_dir, f"mix_{tag}")
        eval_one_run(run_dir, X_layers, y_true,
                     restrict(_fused_axis_loader(mix_axes, mix_weights, H)))
        return

    # mode == "auto": discover usage axes, then grid-search fusion weights with main.
    discovered = _discover_usage_axes(args.axis_root, args.axis)
    if not discovered:
        print("⚠️ No usage axes found. Will run solo(main) only.")

    w_main_grid = parse_float_list(args.w_main_grid) or [1.0]
    # An empty per-axis grid would silently drop that axis's weighted runs;
    # default it to [1.0], like w_main_grid.
    w_grids = {name: parse_float_list(v) or [1.0]
               for name, v in parse_kv_list(args.w_grid).items()}
    for name in discovered:
        w_grids.setdefault(name, [1.0])
    want = {x.strip() for x in (args.combos or "solo,pairwise,all").split(",")}

    runs = _plan_auto_runs(args.axis_dir_main, discovered, w_main_grid, w_grids, want)
    print(f"🧪 Planned runs: {len(runs)}")
    overview = []
    for tag, axis_sel, weights in runs:
        run_name = sanitize(tag)
        run_dir = os.path.join(args.out_dir, f"auto_{run_name}")
        # The loader captures axis_sel/weights explicitly (no late-bound loop variables).
        loader = restrict(_fused_axis_loader(axis_sel, weights, H))
        info = eval_one_run(run_dir, X_layers, y_true, loader)
        if info:
            overview.append({"run": run_name, **info})

    if overview:
        overview_path = os.path.join(args.out_dir, "auto_overview.csv")
        pd.DataFrame(overview).sort_values("best_acc", ascending=False)\
          .to_csv(overview_path, index=False)
        print(f"✅ Wrote overview: {overview_path}")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["main","usage","mix","auto"], required=True)

    # Model & data
    ap.add_argument("--model_id", required=True)
    ap.add_argument("--audio_json", required=True)
    ap.add_argument("--wav_key", default="wav_path")
    ap.add_argument("--label_key", default="label")

    # Shared by all modes
    ap.add_argument("--only_layers", default=None, help="如 5-20 或 6,8,10")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--device", default="cuda", choices=["cuda","cpu"])
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--sr", type=int, default=16000)

    # mode=main (also required by mode=auto)
    ap.add_argument("--axis_dir_main", default=None)

    # mode=usage
    ap.add_argument("--usage_dir", default=None)
    ap.add_argument("--usage_root", default=None)

    # mode=mix (manual); --axis also adds manual axes in mode=auto
    ap.add_argument("--axis", action="append", default=[], help="重复：name=DIR")
    ap.add_argument("--w",    action="append", default=[], help="重复：name=WEIGHT")

    # mode=auto (automatic discovery + weight grid)
    ap.add_argument("--axis_root", action="append", default=[], help="重复：一个或多个 usage 轴根目录（脚本会遍历子目录）")
    ap.add_argument("--w_main_grid", default="1.0", help="如: 1.0,0.5,2.0")
    ap.add_argument("--w_grid", action="append", default=[], help="重复：name=v1,v2 例：tone=1.0,0.6")
    # "triples" is supported by main() as well; the help text now lists it.
    ap.add_argument("--combos", default="solo,pairwise,all", help="选择要跑的组合：solo,pairwise,triples,all 的子集，用逗号分隔")
    args = ap.parse_args()
    main(args)
