#!/usr/bin/env python3
# test_qwen_audio_axis.py
# Single-axis or multi-axis weighted evaluation for Qwen2-Audio.
# - Backward compatible: if --combo is not provided, evaluate the single axis in --axis_dir.
# - If --combo is provided, combine axes by "dir:weight,dir:weight,...".
#   Each dir must contain sentiment_axis_L{lid}.npy for the layer being evaluated.
# - By default, each axis is unit-normalized before weighting/summing, and the final
#   combined vector is unit-normalized. Use --no_norm_each to disable per-axis normalization.

import os, json, argparse
import numpy as np
import torch
import librosa
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import AutoProcessor

# Optional model classes: record whether each one can be imported so that
# load_qwen2() can choose a loader at runtime instead of failing at import time.
try:
    from transformers import Qwen2AudioForConditionalGeneration
    HAS_QWEN2_CLASS = True
except Exception:
    # Any import-time failure simply marks the dedicated class as unavailable.
    HAS_QWEN2_CLASS = False

# Generic fallback loader, probed the same way as the dedicated class above.
try:
    from transformers import AutoModel
    HAS_AUTOMODEL = True
except Exception:
    HAS_AUTOMODEL = False

# This script only runs inference, so autograd is switched off globally
# (note: this is a process-wide side effect of importing/running the module).
torch.set_grad_enabled(False)

# -------------------- utils --------------------
def unit(x: np.ndarray) -> np.ndarray:
    """Return ``x`` divided by its norm (``np.linalg.norm`` with default arguments).

    A small epsilon is added to the norm so that an all-zero input comes back
    as zeros rather than triggering a division by zero.
    """
    return x / (np.linalg.norm(x) + 1e-12)

def safe_auc(y, s):
    """Compute ROC AUC of scores ``s`` against labels ``y``; return NaN on any failure.

    Any exception raised while scoring is swallowed on purpose so that a
    single un-scorable layer does not abort the whole evaluation run.
    """
    try:
        auc = roc_auc_score(y, s)
    except Exception:
        # Best-effort metric: report "not available" instead of crashing.
        auc = float("nan")
    return auc

def parse_combo(combo_str: str):
    """
    Parse a combo specification such as 'dirA:1.0,dirB:0.5,dirC:-0.2'.

    Entries are comma-separated. Within an entry the text after the LAST
    colon is the weight, so paths that themselves contain a colon still work
    as long as a weight is given. An entry without any colon gets weight 1.0.
    Entries whose weight is not a valid float are reported and skipped.

    Returns a list of (path, weight) tuples; empty for a falsy ``combo_str``.
    """
    pairs = []
    if not combo_str:
        return pairs

    for raw in combo_str.split(","):
        token = raw.strip()
        if not token:
            continue

        head, sep, tail = token.rpartition(":")
        if not sep:
            # No colon at all: the whole token is the path, default weight.
            pairs.append((token, 1.0))
            continue

        try:
            weight = float(tail)
        except ValueError:
            print(f"⚠️ Failed to parse weight, skipping: {token}")
            continue
        pairs.append((head.strip(), weight))

    return pairs

# -------------------- IO --------------------
def load_audio_list(jpath, wav_key="wav_path", label_key="label"):
    """Read a JSON manifest and return the usable (id, wav path, label) triples.

    The manifest is iterated as a sequence of dict-like records. A record is
    kept only if:
      * ``wav_key`` maps to a non-empty path that exists on disk, and
      * ``label_key`` is present and is either the string "positive" /
        "negative" (case-insensitive; mapped to 1 / 0) or a value accepted by
        ``int()``. Other string labels are skipped.

    Args:
        jpath: Path to the JSON manifest (read as UTF-8).
        wav_key: Record key holding the audio file path.
        label_key: Record key holding the label.

    Returns:
        ``(ids, wavs, labels)`` where ``ids`` and ``wavs`` are lists and
        ``labels`` is an integer numpy array. The id is the record's
        ``"utt_id"`` if present, otherwise the basename of the wav path.
    """
    # Use a context manager so the manifest handle is closed deterministically.
    with open(jpath, "r", encoding="utf-8") as fh:
        data = json.load(fh)

    ids, wavs, labels = [], [], []
    for d in data:
        p = d.get(wav_key)
        lab = d.get(label_key)
        if not p or lab is None:
            continue

        if isinstance(lab, str):
            lab = lab.lower()
            if lab not in {"positive", "negative"}:
                continue
            y = 1 if lab == "positive" else 0
        else:
            y = int(lab)

        # Entries pointing at files that are not on disk are dropped silently.
        if os.path.exists(p):
            wavs.append(p)
            labels.append(y)
            ids.append(d.get("utt_id", os.path.basename(p)))

    return ids, wavs, np.array(labels, dtype=int)

def _detect_audio_token(processor):
    """Return the audio placeholder token string to use with ``processor``.

    Preference order:
      1. ``processor.tokenizer.audio_token`` when it is set and truthy;
      2. the first of a few known spellings that the tokenizer maps to an id
         that is neither ``None`` nor the unknown-token id;
      3. the default "<|AUDIO|>".
    Any error raised while probing a spelling just moves on to the next one.
    """
    tokenizer = getattr(processor, "tokenizer", None)
    declared = getattr(tokenizer, "audio_token", None)
    if declared:
        return declared

    fallback = "<|AUDIO|>"
    for candidate in (fallback, "<|audio|>", "<|Audio|>"):
        try:
            token_id = processor.tokenizer.convert_tokens_to_ids(candidate)
            known = token_id is not None and token_id != processor.tokenizer.unk_token_id
        except Exception:
            continue
        if known:
            return candidate
    return fallback

def safe_get_hidden_states(outputs):
    """Pull the per-layer hidden states out of a model output object.

    The first non-None attribute among ``audio_hidden_states``,
    ``encoder_hidden_states`` and ``hidden_states`` wins; failing that,
    ``outputs.language_model_output.hidden_states`` is used.

    Raises:
        RuntimeError: if none of those locations holds a value.
    """
    candidates = ("audio_hidden_states", "encoder_hidden_states", "hidden_states")
    for name in candidates:
        states = getattr(outputs, name, None)
        if states is not None:
            return states

    nested = getattr(outputs, "language_model_output", None)
    nested_states = None if nested is None else getattr(nested, "hidden_states", None)
    if nested_states is None:
        raise RuntimeError("hidden_states not found")
    return nested_states

def read_waves(paths, sr):
    """Load each path as a mono waveform resampled to ``sr``.

    Loading is best-effort: if a file cannot be read, a warning naming the
    file and the error is printed and a one-sample float32 zero array is
    used in its place, so the returned list always has one entry per path
    (keeping batch positions aligned with the input paths).

    Args:
        paths: Iterable of audio file paths.
        sr: Target sampling rate passed to ``librosa.load``.

    Returns:
        List of 1-D numpy arrays, one per input path, in input order.
    """
    waves = []
    for p in paths:
        try:
            y, _ = librosa.load(p, sr=sr, mono=True)
        except Exception as e:
            # Surface the failure: a silent placeholder would otherwise be
            # scored like real audio and skew the metrics without any trace.
            print(f"⚠️ Failed to load audio, substituting silence: {p} ({e!r})")
            y = np.zeros(1, dtype=np.float32)
        waves.append(y)
    return waves

def load_qwen2(model_id, device="cuda"):
    """Load the processor and model for ``model_id`` and move the model to ``device``.

    The dedicated ``Qwen2AudioForConditionalGeneration`` class is tried first
    when it was importable; if that attempt raises, the error is remembered
    and ``AutoModel`` is tried as a fallback when available. The model is put
    in eval mode before being moved to the device.

    Returns:
        ``(processor, model)``.

    Raises:
        RuntimeError: if no loader produced a model; the message carries the
            error from the dedicated-class attempt (``None`` if it never ran).
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    model = None
    err = None
    if HAS_QWEN2_CLASS:
        try:
            loaded = Qwen2AudioForConditionalGeneration.from_pretrained(
                model_id, trust_remote_code=True
            )
            model = loaded.eval().to(device)
        except Exception as exc:
            err = exc

    if model is None and HAS_AUTOMODEL:
        # Errors from the fallback loader are not caught here and propagate.
        fallback = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        model = fallback.eval().to(device)

    if model is None:
        raise RuntimeError(f"Failed to load {model_id}: {err!r}")
    return processor, model

def batch_encode_dynamic(wav_paths, processor, model, batch_size=8, target_sr=16000, device="cuda"):
    """Encode audio files and return one time-pooled embedding per file for every layer.

    A single-file probe pass first discovers how many hidden-state layers the
    model exposes and their width; the files are then encoded in batches and
    each layer's hidden states are pooled over the time axis.

    Args:
        wav_paths: Sequence of audio file paths (must be non-empty).
        processor: Callable processor accepting ``audio=``, ``text=``,
            ``sampling_rate=``, ``return_tensors=`` and ``padding=``, whose
            result supports ``.to(device)``, ``in`` and ``**`` unpacking.
        model: Model called with the processor output and
            ``output_hidden_states=True``.
        batch_size: Number of files per forward pass.
        target_sr: Sampling rate used for loading and passed to the processor.
        device: Device the inputs are moved to ("cuda" also triggers a cache
            flush after every batch).

    Returns:
        ``(embs, num_layers, hidden_size)`` where ``embs`` is a list of
        ``num_layers`` float32 arrays of shape ``(len(wav_paths), hidden_size)``.

    Raises:
        ValueError: if ``wav_paths`` is empty (there is nothing to probe with).
    """
    N = len(wav_paths)
    if N == 0:
        raise ValueError("wav_paths is empty: at least one audio file is required.")
    audio_tok = _detect_audio_token(processor)

    # Probe number of layers and hidden size. The probe goes through
    # read_waves() so that an unreadable first file is handled exactly like an
    # unreadable file in any later batch instead of aborting the whole run.
    probe_waves = read_waves(wav_paths[:1], target_sr)
    probe = processor(audio=probe_waves, text=[audio_tok], sampling_rate=target_sr,
                      return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        out0 = model(**probe, output_hidden_states=True)
    hs0 = safe_get_hidden_states(out0)
    # If the first two entries share a width, entry 0 is skipped
    # (presumably the embedding output rather than a layer — TODO confirm).
    if len(hs0) >= 2 and hs0[0].shape[-1] == hs0[1].shape[-1]:
        start_idx, num_layers = 1, len(hs0) - 1
    else:
        start_idx, num_layers = 0, len(hs0)
    hidden_size = hs0[start_idx].shape[-1]
    # Drop the probe tensors now so they do not stay alive during encoding.
    del probe, out0, hs0

    embs = [np.empty((N, hidden_size), dtype=np.float32) for _ in range(num_layers)]

    # Actual encoding
    for i in tqdm(range(0, N, batch_size), desc="🎧 Encode"):
        chunk = wav_paths[i:i+batch_size]
        waves = read_waves(chunk, target_sr)
        inputs = processor(audio=waves, text=[audio_tok]*len(waves),
                           sampling_rate=target_sr, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hs = safe_get_hidden_states(outputs)

        # Use the first available attention mask for time pooling.
        attn = None
        for k in ("attention_mask", "audio_attention_mask", "input_attention_mask"):
            if k in inputs and inputs[k] is not None:
                attn = inputs[k]
                break
        mask = None
        if attn is not None:
            # Collapse a singleton middle/last axis so the mask becomes 2-D.
            if attn.dim() == 3:
                if attn.size(1) == 1:
                    attn = attn.squeeze(1)
                elif attn.size(2) == 1:
                    attn = attn.squeeze(2)
            if attn.dim() == 2:
                mask = attn.unsqueeze(-1).float()

        for l in range(num_layers):
            h = hs[start_idx + l]
            if h.dim() == 3:
                # Masked mean only when the mask lines up with this layer's
                # first two dimensions; otherwise fall back to a plain mean.
                if mask is not None and mask.shape[:2] == h.shape[:2]:
                    pooled = (h*mask).sum(1) / mask.sum(1).clamp(min=1e-6)
                else:
                    pooled = h.mean(1)
            else:
                pooled = h.reshape(h.size(0), -1)
            # Cast to float32 before leaving torch: numpy cannot represent
            # bfloat16, so a reduced-precision model would fail on .numpy().
            embs[l][i:i+len(chunk)] = pooled.detach().float().cpu().numpy()

        del inputs, outputs, hs, waves
        if device == "cuda":
            torch.cuda.empty_cache()
    return embs, num_layers, hidden_size

# -------------------- axis loading & combination --------------------
def load_axis_from_dir_for_layer(dir_path: str, lid: int):
    """
    Return the array stored in ``<dir_path>/sentiment_axis_L{lid}.npy``,
    or None when that file does not exist.
    """
    axis_file = os.path.join(dir_path, f"sentiment_axis_L{lid}.npy")
    return np.load(axis_file) if os.path.exists(axis_file) else None

def build_combined_axis(lid: int, H: int, axis_dir: str, combo_list, norm_each: bool, verbose=True):
    """
    Produce the projection axis for layer `lid`, or None if none is usable.

    - Empty/falsy `combo_list`: load `sentiment_axis_L{lid}.npy` from
      `axis_dir`, check its last dimension equals `H`, and unit-normalize it.
    - Otherwise: for every (dir, weight) pair load that directory's layer
      axis, skip it if the file is missing or its last dimension is not `H`,
      unit-normalize it when `norm_each` is set, scale it by the weight, then
      sum the surviving vectors and unit-normalize the sum.

    With `verbose`, every skipped or missing axis is reported on stdout.
    """
    if not combo_list:
        # Single-axis mode.
        axis_path = os.path.join(axis_dir, f"sentiment_axis_L{lid}.npy")
        if not os.path.exists(axis_path):
            if verbose:
                print(f"⚠️ L{lid}: missing axis {axis_path}")
            return None
        axis = np.load(axis_path)
        dim = axis.shape[-1]
        if dim != H:
            if verbose:
                print(f"⚠️ L{lid}: dim mismatch {dim} vs {H}")
            return None
        return unit(axis)

    # Multi-axis mode: collect the weighted contributions that are usable.
    weighted = []
    for src_dir, weight in combo_list:
        axis = load_axis_from_dir_for_layer(src_dir, lid)
        if axis is None:
            if verbose:
                print(f"⚠️ L{lid}: axis file not found in {src_dir}")
            continue
        dim = axis.shape[-1]
        if dim != H:
            if verbose:
                print(f"⚠️ L{lid}: dim mismatch {dim} vs {H} @ {src_dir}")
            continue
        if norm_each:
            axis = unit(axis)
        weighted.append(weight * axis)

    if not weighted:
        return None
    # The combined direction is always normalized, regardless of `norm_each`.
    return unit(np.sum(weighted, axis=0))

# -------------------- main --------------------
def main(args):
    """Run the per-layer axis-projection evaluation described by ``args``.

    Steps: load the audio manifest, load the model, encode every file into
    per-layer pooled embeddings, then for each selected layer project the
    embeddings onto the (single or combined) axis, threshold at 0, and report
    accuracy / F1 / AUROC. One ``proj_L{lid}_{mode}.csv`` per layer and a
    ``summary.csv`` are written to ``args.out_dir``.

    Args:
        args: Namespace providing ``out_dir``, ``audio_json``, ``wav_key``,
            ``label_key``, ``model_id``, ``device``, ``batch_size``, ``sr``,
            ``combo``, ``axis_dir``, ``only_layers`` and ``no_norm_each``.

    Raises:
        ValueError: if the manifest yields no usable audio file.
    """
    os.makedirs(args.out_dir, exist_ok=True)

    # Load data
    utt_ids, wavs, y_true = load_audio_list(args.audio_json, args.wav_key, args.label_key)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(wavs) == 0:
        raise ValueError("No audio found.")

    # Load model
    processor, model = load_qwen2(args.model_id, args.device)

    # Encode per-layer embeddings
    embs_by_layer, L, H = batch_encode_dynamic(
        wavs, processor, model,
        batch_size=args.batch_size, target_sr=args.sr, device=args.device
    )
    print(f"🔍 Derived structure: layers={L}, hidden_size={H}")

    combo_list = parse_combo(args.combo)
    if combo_list:
        print("🔗 Using axis combination:")
        for p, w in combo_list:
            print(f"   - {p} * {w}")
    else:
        print(f"🧭 Using single axis_dir: {args.axis_dir}")

    # Select layers (1-based). Specs look like '6,14,24' or '6-12,20'.
    layers = list(range(1, L+1))
    if args.only_layers:
        layers = []
        for tok in args.only_layers.split(","):
            tok = tok.strip()
            if not tok:
                continue
            if "-" in tok:
                a, b = tok.split("-", 1)
                a, b = int(a), int(b)
                layers.extend(range(a, b+1))
            else:
                layers.append(int(tok))
        # Deduplicate and silently drop layer ids outside 1..L.
        layers = sorted({lid for lid in layers if 1 <= lid <= L})

    rows = []
    for lid in layers:
        # Build axis (single or combined) for this layer
        u = build_combined_axis(
            lid, H, args.axis_dir, combo_list,
            norm_each=(not args.no_norm_each),
            verbose=True
        )
        if u is None:
            continue

        # Layer representation (embs_by_layer is 0-based, layer ids are 1-based)
        X = embs_by_layer[lid-1]

        # Scoring and metrics (sign auto-alignment)
        scores = X @ u
        preds = (scores > 0).astype(int)
        acc = accuracy_score(y_true, preds)
        f1 = f1_score(y_true, preds)
        auc = safe_auc(y_true, scores)
        # If accuracy < 0.5, flip sign for reporting consistency.
        # NOTE: the flip is decided from the ground-truth labels, so the
        # reported metrics are independent of the axis orientation.
        if acc < 0.5:
            scores = -scores
            preds = 1 - preds
            acc = accuracy_score(y_true, preds)
            f1 = f1_score(y_true, preds)
            auc = safe_auc(y_true, scores)

        tag = "combo" if combo_list else "single"
        pd.DataFrame({
            "utt_id": utt_ids,
            "wav_path": wavs,
            "score": scores,
            "pred_label": preds,
            "true_label": y_true,
            "correct": (preds == y_true).astype(int),
            "layer": lid,
            "mode": tag
        }).to_csv(os.path.join(args.out_dir, f"proj_L{lid}_{tag}.csv"), index=False)

        rows.append({"layer": lid, "mode": tag, "acc": acc, "f1": f1, "auroc": auc})
        print(f"  L{lid:02d} [{tag}] | Acc {acc:.4f} | F1 {f1:.4f} | AUROC {auc:.4f}")

    if rows:
        df = pd.DataFrame(rows).sort_values("layer")
        df.to_csv(os.path.join(args.out_dir, "summary.csv"), index=False)
        best = df.loc[df["acc"].idxmax()]
        # Use bracket indexing: attribute access `best.mode` would resolve to
        # the pandas `Series.mode` method rather than the "mode" column value.
        print(f"\n✅ Best L{int(best['layer'])} [{best['mode']}] | Acc {best['acc']:.4f} | F1 {best['f1']:.4f} | AUROC {best['auroc']:.4f}")
        print(f"✅ Saved summary to {os.path.join(args.out_dir,'summary.csv')}")
    else:
        print("No results. (No axis files found or all mismatched dims.)")

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()

    # Model and axis selection.
    parser.add_argument("--model_id", required=True)
    parser.add_argument(
        "--axis_dir",
        required=True,
        help="Directory containing sentiment_axis_L{lid}.npy (used in single-axis mode)",
    )
    parser.add_argument(
        "--combo",
        default=None,
        help='Multi-axis weighted combo: "dirA:wA,dirB:wB,dirC:wC"; each dir must contain sentiment_axis_L{lid}.npy',
    )
    parser.add_argument(
        "--only_layers",
        default=None,
        help="Evaluate specific layers only, e.g., '6,14,24' or '6-12,20'",
    )
    parser.add_argument(
        "--no_norm_each",
        action="store_true",
        help="Disable per-axis unit normalization (final combined vector is still normalized)",
    )

    # Data, output location and runtime settings.
    parser.add_argument("--audio_json", required=True)
    parser.add_argument("--wav_key", default="wav_path")
    parser.add_argument("--label_key", default="label")
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--sr", type=int, default=16000)

    args = parser.parse_args()
    main(args)
