#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Evaluate sentiment using main axis masked by flip/stable neurons.
Supports four strategies for choosing lambda when the flip and stable scores are fused (mode 'flip+stable'):
  - grid  : sweep --lambda_list (original behavior)
  - csv   : read per-layer lambda from --lambda_csv (columns: model, layer, and lambda_flip or lambda; layer labels such as L05)
  - ratio : compute lambda* = Δμ_flip / Δμ_stable online on the current split (clipped to [0, 2])
  - fixed : use a global constant --lambda_fixed

Score fusion (flip+stable):
  s = s_stable + lambda * s_flip

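Each masked axis is sign-aligned (no bias term), so a positive score means "positive" and
predictions threshold the score at 0. Note: sign alignment, the 'ratio' lambda and the
best-layer choice all use the labels of the split being evaluated.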
"""

import os, json, argparse, random, gc, re
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModel

# ---------------- Basic config ----------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
MAX_LEN = 128      # tokenizer truncation / padding length
BATCH_SIZE = 128   # texts per forward pass
SEED = 42

# Seed every RNG the script touches.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Input roots: precomputed main sentiment axes and flip/stable neuron-id lists.
MAIN_AXIS_DIR = "experiments/exp_main_axis"
ID_DIR = "outputs/flip_stable_neurons"

# Per-model settings: HF model id, axis directory, neuron-id directory and the
# number of transformer layers to evaluate.
MODELS = {
    "Gemma-7B": dict(
        model_id="google/gemma-7b",
        axis_dir=os.path.join(MAIN_AXIS_DIR, "gemma-7b_main"),
        id_dir=os.path.join(ID_DIR, "Gemma-7B"),
        layers=28,
    ),
    "LLaMA-3-8B": dict(
        model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        axis_dir=os.path.join(MAIN_AXIS_DIR, "Meta-Llama-3-8B-Instruct_main"),
        id_dir=os.path.join(ID_DIR, "LLaMA-3-8B"),
        layers=32,
    ),
    "Mistral-7B": dict(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
        axis_dir=os.path.join(MAIN_AXIS_DIR, "Mistral-7B-Instruct-v0.2_main"),
        id_dir=os.path.join(ID_DIR, "Mistral-7B"),
        layers=32,
    ),
}

DATASETS = [
    "twitter", "Dialogue", "AnimalsBeingBros", "Confession",
    "Cringe", "OkCupid", "imdb", "sst5",
]

# Local evaluation files; None means the split is fetched via `load_dataset`.
_TEST_DIR = "data/processed/test"
TEST_FILES = {
    "Dialogue": f"{_TEST_DIR}/dailydialog_emotion_filtered.json",
    "AnimalsBeingBros": f"{_TEST_DIR}/AnimalsBeingBros_posneg.json",
    "Confession": f"{_TEST_DIR}/confession_posneg.json",
    "Cringe": f"{_TEST_DIR}/cringe_posneg.json",
    "OkCupid": f"{_TEST_DIR}/OkCupid_posneg.json",
    "sst5": f"{_TEST_DIR}/sst5_binary_phrases.txt",
    "twitter": None,
    "imdb": None,
}

# ---------------- Utils ----------------
# One optionally negative ASCII base-10 integer (the whole stripped line).
_INT_LINE_RE = re.compile(r"-?[0-9]+")


def _read_id_file(path):
    """Read integer neuron ids, one per line, from a text file.

    Blank lines and lines that are not a single (optionally negative) ASCII
    integer are skipped. A missing file yields an empty list.

    Args:
        path: path of the id file.

    Returns:
        list[int] of the ids in file order (negative values are kept; callers
        filter them).
    """
    if not os.path.exists(path):
        return []
    ids = []
    # utf-8-sig transparently drops a leading BOM if the file has one.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            token = line.strip()
            # Strict check: `token.lstrip("-").isdigit()` accepted "--5" and
            # non-ASCII digits such as "²", on which int() raises ValueError.
            if _INT_LINE_RE.fullmatch(token):
                ids.append(int(token))
    return ids

def load_neuron_ids(id_dir, tag, L):
    """Return the non-negative neuron indices recorded for layer ``L``.

    ``tag`` picks the sub-directory ('flip' or 'stable'); the file read is
    ``<id_dir>/<tag>/L<NN>.txt`` with a two-digit zero-padded layer number.
    """
    id_path = os.path.join(id_dir, tag, "L{:02d}.txt".format(L))
    return list(filter(lambda idx: idx >= 0, _read_id_file(id_path)))

def mask_axis(v, ids, hidden_size):
    """Keep only the coordinates of ``v`` listed in ``ids`` and return a unit vector.

    Indices outside ``[0, hidden_size)`` are ignored. Returns ``None`` when no
    valid index remains or when the masked vector has zero norm.
    """
    keep = np.zeros(hidden_size, dtype=np.float32)
    valid = [idx for idx in ids if idx >= 0 and idx < hidden_size]
    if len(valid) == 0:
        return None
    keep[valid] = 1.0
    masked = v * keep
    length = np.linalg.norm(masked)
    if length > 0:
        return masked / length
    return None

def sign_align(v_unit, H, y):
    """Orient ``v_unit`` so that positive samples project higher than negatives.

    No bias term is involved: only the sign of the direction may change.

    Args:
        v_unit: direction vector, or None.
        H: sample matrix; ``H @ v_unit`` must give one score per sample
            (presumably one row per sample -- matches how callers use it).
        y: binary labels (1 = positive, 0 = negative), one per sample; any
            array-like (lists are accepted).

    Returns:
        ``(aligned_vector, gap)`` where ``gap = |mean(pos) - mean(neg)|`` of the
        projections. ``(None, 0.0)`` when ``v_unit`` is None. When either class
        is absent the gap is undefined: the vector is returned unchanged with
        gap 0.0.
    """
    if v_unit is None:
        return None, 0.0
    y = np.asarray(y)  # boolean masks below need an array, not a list
    z = H @ v_unit
    pos, neg = z[y == 1], z[y == 0]
    if pos.size == 0 or neg.size == 0:
        # mean() of an empty slice is NaN, and `NaN >= 0` is False, which used to
        # flip the vector for no reason and report a NaN gap.
        return v_unit, 0.0
    gap = float(pos.mean() - neg.mean())
    if gap >= 0:
        return v_unit, gap
    return -v_unit, -gap

def cosine(u, v):
    """Cosine similarity of two already-normalised vectors; NaN if either is None."""
    if u is not None and v is not None:
        # Both inputs are unit vectors, so the dot product is the cosine.
        return float(np.dot(u, v))
    return np.nan

def delta_mu(scores, labels):
    """Return Δμ = mean(scores where label == 1) - mean(scores where label == 0).

    Returns 0.0 when either class has no samples.
    """
    pos_scores = scores[labels == 1]
    neg_scores = scores[labels == 0]
    if len(pos_scores) and len(neg_scores):
        return float(pos_scores.mean() - neg_scores.mean())
    return 0.0

# Layer labels like 5, "5", "5.0", "L5", "L05" -> captured layer number.
_LAYER_LABEL_RE = re.compile(r"[Ll]?([0-9]+)(?:\.0*)?")


def _normalize_layer_label(raw):
    """Map a CSV layer label to the canonical "Lxx" form used as lookup key.

    Labels that do not look like a layer number are returned stripped but
    otherwise unchanged.
    """
    text = str(raw).strip()
    m = _LAYER_LABEL_RE.fullmatch(text)
    return f"L{int(m.group(1)):02d}" if m else text


def load_lambda_table(csv_path):
    """Load per-layer fusion weights from a CSV file.

    The CSV needs ``model`` and ``layer`` columns plus a ``lambda_flip`` column
    (or, as a fallback, ``lambda``). Layer labels are normalised to "Lxx"
    (two-digit zero padding), so 5, "5", "L5" and "L05" all map to "L05".
    Rows with an empty/NaN lambda are skipped so they fall back to the caller's
    default instead of turning every fused score into NaN.

    Args:
        csv_path: path of the CSV file.

    Returns:
        dict mapping ``(model_name, "Lxx")`` to a float lambda.

    Raises:
        KeyError: if a required column is missing.
    """
    df = pd.read_csv(csv_path)
    if "lambda_flip" in df.columns:
        col_lambda = "lambda_flip"
    elif "lambda" in df.columns:
        col_lambda = "lambda"
    else:
        raise KeyError("CSV 必须包含 'lambda_flip' 或 'lambda' 列")
    missing = [c for c in ("model", "layer") if c not in df.columns]
    if missing:
        raise KeyError(f"lambda CSV is missing required column(s): {missing}")

    table = {}
    for model, layer, lam in zip(df["model"], df["layer"], df[col_lambda]):
        lam = float(lam)
        if np.isnan(lam):
            continue
        table[(str(model).strip(), _normalize_layer_label(layer))] = lam
    return table


def lambda_by_ratio(H, vf, vs, labels, clip=(0.0, 2.0)):
    """Estimate the flip weight as λ* = Δμ_flip / Δμ_stable on the given split.

    Both masked directions are sign-aligned against ``labels`` (always, even if
    the caller already aligned them), projected onto ``H``, and their class-mean
    gaps compared. Raw projections are used (no standardisation).

    Returns 0.0 if either direction is None or if |Δμ_stable| < 1e-8. When
    ``clip`` is a ``(low, high)`` pair the ratio is clipped into that range;
    ``clip=None`` returns it unclipped.
    """
    if vf is None or vs is None:
        return 0.0
    flip_dir = sign_align(vf, H, labels)[0]
    stable_dir = sign_align(vs, H, labels)[0]
    flip_scores = (H @ flip_dir).ravel()
    stable_scores = (H @ stable_dir).ravel()
    gap_flip = delta_mu(flip_scores, labels)
    gap_stable = delta_mu(stable_scores, labels)
    if abs(gap_stable) < 1e-8:
        ratio = 0.0
    else:
        ratio = gap_flip / gap_stable
    if clip is None:
        return ratio
    return float(np.clip(ratio, clip[0], clip[1]))

# ---------------- Data ----------------
# Keys tried, in order, for a JSON record's label; the first non-empty one decides.
_JSON_LABEL_KEYS = ("label", "valence", "sentiment", "emotion")
# sst5 line format: "<text><comma/tab/space separator><positive|negative>".
_SST_LINE_RE = re.compile(r"^(?P<text>.+?)\s*[,\t ]+\s*(?P<label>positive|negative)\s*$", re.IGNORECASE)


def _json_record_label(rec):
    """Return 1 (positive) / 0 (negative) for a JSON record, or None to skip it.

    The first key of ``_JSON_LABEL_KEYS`` whose value is not None or "" decides.
    Strings must read "positive"/"negative" (case-insensitive, surrounding
    whitespace ignored); integers must be 0 or 1. Anything else skips the record.
    Unlike an ``a or b or ...`` chain, an integer label 0 counts as present.
    """
    for key in _JSON_LABEL_KEYS:
        val = rec.get(key)
        if val is None or val == "":
            continue
        if isinstance(val, str):
            tag = val.strip().lower()
            if tag in ("positive", "negative"):
                return 1 if tag == "positive" else 0
            return None
        if isinstance(val, int) and val in (0, 1):
            return int(val)
        return None
    return None


def load_dataset_split(name):
    """Load the evaluation split of dataset ``name`` as ``(texts, labels)``.

    Labels are 1 = positive, 0 = negative.

    - "twitter": tweet_eval/sentiment test split; neutral (label 1) is dropped,
      label 2 becomes 1 and label 0 stays 0.
    - "imdb": imdb test split with its labels as-is.
    - "sst5": lines of ``TEST_FILES["sst5"]`` matching ``_SST_LINE_RE``.
    - anything else: JSON list of records at ``TEST_FILES[name]``; text comes
      from "text" or "utterance", the label from ``_json_record_label``.

    Returns:
        (list of str, numpy array of labels).

    Raises:
        KeyError: for a file-based dataset name missing from ``TEST_FILES``.
    """
    if name == "twitter":
        ds = load_dataset("tweet_eval", "sentiment", split="test")
        texts, labels = [], []
        for t, l in zip(ds["text"], ds["label"]):
            if l in (0, 2):  # drop neutral
                texts.append(t)
                labels.append(1 if l == 2 else 0)
        return texts, np.array(labels)

    if name == "imdb":
        ds = load_dataset("imdb", split="test")
        # list(): callers slice and tokenize this; some `datasets` versions may
        # return a lazy column object rather than a list -- normalise it here.
        return list(ds["text"]), np.array(ds["label"])

    if name == "sst5":
        texts, labels = [], []
        with open(TEST_FILES[name], "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                m = _SST_LINE_RE.match(line.strip())
                if not m:
                    continue
                text = m.group("text").strip()
                if text:
                    texts.append(text)
                    labels.append(1 if m.group("label").lower() == "positive" else 0)
        return texts, np.array(labels)

    with open(TEST_FILES[name], "r", encoding="utf-8") as f:
        data = json.load(f)
    texts, labels = [], []
    for rec in data:
        label = _json_record_label(rec)
        if label is None:
            continue
        text = rec.get("text") or rec.get("utterance")
        if text:
            texts.append(text)
            labels.append(label)
    return texts, np.array(labels)

@torch.inference_mode()
def encode_all_layers(texts, tokenizer, model, num_layers):
    """Mean-pool the hidden states of every transformer layer for ``texts``.

    Texts are tokenized in batches of ``BATCH_SIZE``, padded/truncated to
    ``MAX_LEN``, and each layer's hidden states are averaged over positions
    where ``attention_mask == 1``.

    Args:
        texts: sequence of strings.
        tokenizer: tokenizer called as ``tokenizer(batch, return_tensors="pt", ...)``.
        model: model returning ``hidden_states`` when ``output_hidden_states=True``.
        num_layers: number of transformer layers to keep; the last
            ``num_layers`` entries of ``hidden_states`` are used (leading
            entries, usually just the embedding output, are skipped).

    Returns:
        list of ``num_layers`` float32 arrays of shape ``(len(texts), hidden_size)``;
        index ``L - 1`` holds layer ``L``.

    Raises:
        ValueError: if the model returns fewer hidden states than ``num_layers``.
    """
    hidden_size = model.config.hidden_size
    n = len(texts)
    all_layers = [np.empty((n, hidden_size), dtype=np.float32) for _ in range(num_layers)]
    for start in tqdm(range(0, n, BATCH_SIZE), desc="🧠 Encoding"):
        stop = start + BATCH_SIZE
        batch = list(texts[start:stop])
        inputs = tokenizer(batch, return_tensors="pt", padding="max_length",
                           truncation=True, max_length=MAX_LEN).to(model.device)
        hs = model(**inputs, output_hidden_states=True).hidden_states
        offset = len(hs) - num_layers  # usually 1 to skip embeddings
        if offset < 0:
            # A negative offset would silently wrap around to the wrong layers.
            raise ValueError(
                f"model returned {len(hs)} hidden states but num_layers={num_layers}"
            )
        # Pool in float32: fp16 sums over up to MAX_LEN tokens can overflow, and
        # Tensor.numpy() does not support bfloat16. No-op for float32 models.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        denom = mask.sum(1).clamp_min(1.0)
        for L in range(1, num_layers + 1):
            pooled = (hs[offset + L - 1].float() * mask).sum(1) / denom
            all_layers[L - 1][start:stop] = pooled.cpu().numpy()
    return all_layers

def load_axis(axis_dir, L):
    """Load the stored sentiment axis of layer ``L`` as a unit vector.

    Reads ``<axis_dir>/sentiment_axis_L<L>.npy`` (layer number NOT zero-padded).
    For a multi-dimensional array the first row is used. Prints a warning and
    returns None if the file is missing; returns None if the vector has zero norm.
    """
    axis_path = os.path.join(axis_dir, "sentiment_axis_L{}.npy".format(L))
    if not os.path.exists(axis_path):
        print("❌ Missing axis", axis_path)
        return None
    axis = np.load(axis_path)
    if axis.ndim != 1:
        axis = axis[0]
    length = np.linalg.norm(axis)
    if length > 0:
        return axis / length
    return None

# ---------------- Main ----------------
_VALID_MODES = ("flip", "stable", "flip+stable")


def _build_arg_parser():
    """Build the command-line interface of the script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--modes", default="flip,stable,flip+stable",
                    help="Options: flip,stable,flip+stable (comma separated)")
    ap.add_argument("--out_dir", default="output/results_main_mask_lambda")
    ap.add_argument("--inspect", action="store_true",
                    help="Print per-layer Δμ/cosine/|ids| diagnostics")
    ap.add_argument("--no_sign_align", action="store_true",
                    help="Disable sign alignment (enabled by default)")

    # Lambda strategies
    ap.add_argument("--lambda_mode", type=str, default="grid",
                    choices=["grid", "csv", "ratio", "fixed"],
                    help="grid: sweep --lambda_list; "
                         "csv: read layer-wise lambda from --lambda_csv; "
                         "ratio: lambda* = Δμ_flip/Δμ_stable online; "
                         "fixed: use --lambda_fixed")

    ap.add_argument("--lambda_list", type=str, default="0,0.1,0.25,0.5,1.0",
                    help="Only used when --lambda_mode=grid")

    ap.add_argument("--lambda_csv", type=str, default="",
                    help="Only used when --lambda_mode=csv; "
                         "CSV columns: model,layer,lambda_flip")

    ap.add_argument("--lambda_fixed", type=float, default=0.25,
                    help="Only used when --lambda_mode=fixed")
    return ap


def _parse_lambda_list(text):
    """Parse a comma-separated list of floats, ignoring empty items (e.g. a trailing comma)."""
    return [float(tok) for tok in text.split(",") if tok.strip()]


def _fusion_lambda(args, grid_lam, lambda_table, model_name, L, H, vf, vs, labels):
    """Return the flip weight for one layer according to ``args.lambda_mode``.

    grid  -> the swept value ``grid_lam``
    fixed -> ``args.lambda_fixed``
    csv   -> ``lambda_table[(model_name, "Lxx")]``; 0.0 (stable only) if missing
    ratio -> ``lambda_by_ratio`` on this layer, clipped to [0, 2]
    """
    mode = args.lambda_mode
    if mode == "grid":
        return float(grid_lam)
    if mode == "fixed":
        return float(args.lambda_fixed)
    if mode == "csv":
        return float(lambda_table.get((model_name, f"L{L:02d}"), 0.0))
    if mode == "ratio":
        return lambda_by_ratio(H, vf, vs, labels, clip=(0.0, 2.0))
    raise ValueError(f"Unknown lambda_mode: {mode}")


def main():
    """Evaluate masked-axis sentiment accuracy for every model/dataset/mode/layer.

    Writes ``all_axis_accuracy.csv`` (one row per model/dataset/mode/layer, plus
    lambda for flip+stable) and ``best_axis.csv`` (best layer per
    model/dataset/mode and, in grid mode, per swept lambda) to ``--out_dir``.
    For flip+stable, ``lambda_flip`` in the best row is the swept value in grid
    mode and the lambda used at the best layer otherwise.

    NOTE: sign alignment, the 'ratio' lambda and the best-layer choice all use
    the labels of the split being evaluated.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()
    align = not args.no_sign_align

    modes = [m.strip() for m in args.modes.split(",") if m.strip()]
    unknown = [m for m in modes if m not in _VALID_MODES]
    if unknown:
        ap.error(f"unknown mode(s) {unknown}; options: flip,stable,flip+stable")

    lam_list = []
    if args.lambda_mode == "grid" and "flip+stable" in modes:
        lam_list = _parse_lambda_list(args.lambda_list)
        if not lam_list:
            ap.error("--lambda_list must contain at least one value")
    if args.lambda_mode == "csv" and not args.lambda_csv:
        ap.error("--lambda_csv is required when --lambda_mode=csv")

    os.makedirs(args.out_dir, exist_ok=True)

    # Load lambda table once if needed: {(model, "Lxx"): lambda}
    lambda_table = load_lambda_table(args.lambda_csv) if args.lambda_mode == "csv" else {}

    all_scores, best_rows = [], []

    for model_name, cfg in MODELS.items():
        print(f"\n🚀 Evaluating {model_name}")
        tok = AutoTokenizer.from_pretrained(cfg["model_id"])
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        model = AutoModel.from_pretrained(cfg["model_id"]).to(DEVICE).eval()
        hidden_size = model.config.hidden_size
        layer_nums = range(1, cfg["layers"] + 1)

        # Axes and neuron ids depend only on (model, layer): load them once per
        # model instead of once per dataset x mode x lambda x layer.
        axes = {L: load_axis(cfg["axis_dir"], L) for L in layer_nums}
        neuron_ids = {
            tag: {L: load_neuron_ids(cfg["id_dir"], tag, L) for L in layer_nums}
            for tag in ("flip", "stable")
        }

        for dname in DATASETS:
            texts, labels = load_dataset_split(dname)
            if len(texts) == 0:
                continue
            print(f"\n📊 {model_name} - {dname}: {len(texts)} samples")

            layers = encode_all_layers(texts, tok, model, cfg["layers"])

            for mode in modes:
                if mode == "flip+stable" and args.lambda_mode == "grid":
                    scan_lams = lam_list
                else:
                    scan_lams = [None]  # lambda unused, or decided per layer

                for scan_idx, lam in enumerate(scan_lams):
                    best_acc, best_layer, best_lam = 0.0, None, None

                    for L in layer_nums:
                        v_main = axes[L]
                        if v_main is None:
                            continue
                        H = layers[L - 1]
                        lam_use = None

                        if mode in ("flip", "stable"):
                            v_masked = mask_axis(v_main, neuron_ids[mode][L], hidden_size)
                            if v_masked is None:
                                continue
                            if align:
                                v_masked, _ = sign_align(v_masked, H, labels)
                            scores = (H @ v_masked).ravel()

                        elif mode == "flip+stable":
                            ids_f = neuron_ids["flip"][L]
                            ids_s = neuron_ids["stable"][L]
                            vf = mask_axis(v_main, ids_f, hidden_size)
                            vs = mask_axis(v_main, ids_s, hidden_size)
                            if vf is None or vs is None:
                                continue

                            # Diagnostics do not depend on lambda: print on the first scan only.
                            if args.inspect and scan_idx == 0:
                                vf_chk, gap_f = sign_align(vf, H, labels)
                                vs_chk, gap_s = sign_align(vs, H, labels)
                                cos_fs = cosine(vf_chk, vs_chk)
                                print(f"[Inspect] {model_name} {dname} L{L:02d} | "
                                      f"k_flip={len(ids_f):4d}, k_stable={len(ids_s):4d} | "
                                      f"Δμ_flip={gap_f:.4f}, Δμ_stable={gap_s:.4f} | cos(f,s)={cos_fs:.3f}")

                            if align:
                                vf, _ = sign_align(vf, H, labels)
                                vs, _ = sign_align(vs, H, labels)

                            lam_use = _fusion_lambda(args, lam, lambda_table, model_name,
                                                     L, H, vf, vs, labels)
                            zf = (H @ vf).ravel()
                            zs = (H @ vs).ravel()
                            scores = zs + lam_use * zf

                        else:  # unreachable: modes are validated above
                            raise ValueError(f"Unknown mode: {mode}")

                        pred = (scores > 0).astype(int)
                        acc = accuracy_score(labels, pred)

                        row = {
                            "model": model_name, "dataset": dname,
                            "mode": mode, "layer": f"L{L:02d}", "acc": acc
                        }
                        if mode == "flip+stable":
                            row["lambda_flip"] = lam_use
                        all_scores.append(row)

                        if acc > best_acc:
                            best_acc, best_layer, best_lam = acc, f"L{L:02d}", lam_use

                    best = {
                        "model": model_name, "dataset": dname,
                        "mode": mode, "best_acc": best_acc, "best_layer": best_layer
                    }
                    if mode == "flip+stable":
                        # Grid: the swept value (identifies the row even if no layer
                        # was evaluated); otherwise the lambda used at the best layer.
                        best["lambda_flip"] = float(lam) if lam is not None else best_lam
                    best_rows.append(best)

            del layers  # free this dataset's activations before the next one

        del model, tok
        torch.cuda.empty_cache(); gc.collect()
        print(f"🧹 Finished {model_name}, GPU memory cleared")

    # save
    df_all = pd.DataFrame(all_scores)
    df_best = pd.DataFrame(best_rows)
    for df in (df_all, df_best):
        if "lambda_flip" not in df.columns:
            df["lambda_flip"] = np.nan

    df_all.to_csv(os.path.join(args.out_dir, "all_axis_accuracy.csv"), index=False)
    df_best.to_csv(os.path.join(args.out_dir, "best_axis.csv"), index=False)

    if df_best.empty:
        # Sorting needs the "model"/"dataset"/... columns, which an empty frame lacks.
        print("\n📋 No results to report.")
        return

    print("\n📋 Best layers (top few):")
    print(df_best.sort_values(
        by=["model", "dataset", "mode", "lambda_flip", "best_acc"],
        ascending=[True, True, True, True, False]
    ).head(30))

# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
