#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Single-axis evaluation with usage-gated flip neurons.

- 每层只读取一条你训练好的“合并轴”（single axis）。
- 对 flip neurons 按 usage 符号(±1/0)做门控；stable neurons 不变。
- 支持 16 种 usage 组合（含空集“∅”）；多 usage 时可配置符号融合策略。
- 不做 alpha 搜索；每层一个分数，取最佳层作为该 dataset 的结果。

Outputs:
1) 每个 (model, dataset, usage_combo) 的 per-layer 曲线：
   {out_dir}/{model}_{dataset}_combo_{combo_name}_perlayer.csv
2) 汇总表（最佳层）：
   {out_dir}/ALL_models_singleaxis_usagegated_summary.csv
"""

import os, re, json, argparse, random, gc, itertools
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModel

# ---------------- 基本配置 ----------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 128      # tokenizer truncation / padding length (tokens)
BATCH_SIZE = 128   # texts per forward pass
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Paths (kept consistent with the existing directory layout).
# Flip/stable neuron ids live in {ID_DIR}/{model}/{subset}/L{XX}.txt
ID_DIR     = "outputs/flip_stable_neurons"
# Per-layer usage CSVs live in {USAGE_DIR}/{model}/layer{L}_usage_neurons.csv
USAGE_DIR  = "outputs/usage_neurons_new"

# Root directory of the trained single ("combined") axes. Each model's
# `single_axis_dir` below is a subdirectory of it that holds one
# sentiment_axis_L{L}.npy file per layer.
SINGLE_AXIS_BASE = "experiments/axis_from_new_neurons"

# Model configs (add/remove as needed). Commented entries are templates.
#   model_id        : Hugging Face hub id passed to from_pretrained
#   layers          : number of layers evaluated (L = 1..layers)
#   single_axis_dir : directory with sentiment_axis_L{L}.npy
#   id_dir          : directory with {subset}/L{XX}.txt neuron-id files
#   usage_csv_dir   : directory with layer{L}_usage_neurons.csv
MODELS = {
    # Example: base Gemma-7B
    # "Gemma-7B": {
    #     "model_id": "google/gemma-7b",
    #     "layers": 28,
    #     "single_axis_dir": os.path.join(SINGLE_AXIS_BASE, "gemma-7b_combined"),
    #     "id_dir": os.path.join(ID_DIR, "Gemma-7B"),
    #     "usage_csv_dir": os.path.join(USAGE_DIR, "Gemma-7B"),
    # },
    # Example: Llama / Mistral (if needed)
    # "LLaMA-3-8B": {
    #     "model_id": "meta-llama/Meta-Llama-3-8B-Instruct",
    #     "layers": 32,
    #     "single_axis_dir": os.path.join(SINGLE_AXIS_BASE, "llama-3-8b_flip_stable"),
    #     "id_dir": os.path.join(ID_DIR, "LLaMA-3-8B"),
    #     "usage_csv_dir": os.path.join(USAGE_DIR, "LLaMA-3-8B"),
    # },
    # "Mistral-7B": {
    #     "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
    #     "layers": 32,
    #     "single_axis_dir": os.path.join(SINGLE_AXIS_BASE, "mistral-7b_flip_stable"),
    #     "id_dir": os.path.join(ID_DIR, "Mistral-7B"),
    #     "usage_csv_dir": os.path.join(USAGE_DIR, "Mistral-7B"),
    # },
    # The instruction-tuned model currently in use.
    "Gemma-7B-IT": {
        "model_id": "google/gemma-7b-it",
        "layers": 28,
        # Adjust to wherever the axes were actually saved.
        "single_axis_dir": os.path.join(SINGLE_AXIS_BASE, "gemma-7b-it_flip_stable"),
        "id_dir": os.path.join(ID_DIR, "Gemma-7B-IT"),
        "usage_csv_dir": os.path.join(USAGE_DIR, "Gemma-7B-IT"),
    },
}

# Datasets evaluated by default.
DATASETS = ["AnimalsBeingBros","Confession","Cringe","Dialogue","OkCupid","twitter","imdb","sst5"]

# Local test files; None means the split is fetched from the HF hub instead
# (see load_dataset_split).
TEST_FILES = {
    "Dialogue": "data/processed/test/dailydialog_emotion_filtered.json",
    "AnimalsBeingBros": "data/processed/test/AnimalsBeingBros_posneg.json",
    "Confession": "data/processed/test/confession_posneg.json",
    "Cringe": "data/processed/test/cringe_posneg.json",
    "OkCupid": "data/processed/test/OkCupid_posneg.json",
    "sst5": "data/processed/test/sst5_binary_phrases.txt",
    "twitter": None,
    "imdb": None
}

# Fixed usage order; all_usage_combos() enumerates its 2**4 = 16 subsets.
USAGES_ORDERED = ["genre","topic","tone","contextual"]
USAGES_SET = set(USAGES_ORDERED)

# ---------------- 数据加载 ----------------
def load_dataset_split(name, path=None):
    """Load the binary-sentiment test split ``name`` as ``(texts, labels)``.

    ``texts`` is a list of strings and ``labels`` an integer numpy array with
    1 = positive and 0 = negative.

    Sources:
      * "twitter": HF ``tweet_eval``/``sentiment`` test split; rows with label 1
        are dropped, label 2 -> 1 and label 0 -> 0.
      * "imdb": HF ``imdb`` test split, labels taken as-is.
      * "sst5": text file with one ``<text><sep><positive|negative>`` per line,
        where ``sep`` is any run of commas/tabs/spaces; lines that do not match
        are counted, reported and skipped.
      * anything else: a JSON list of dict records. The label is the first of
        ``label``/``valence``/``sentiment``/``emotion`` that is set (not None and
        not ""). String labels must be "positive"/"negative" (case-insensitive,
        surrounding whitespace ignored) and int labels must be 0 or 1; other
        records are skipped. The text is ``text`` or, failing that,
        ``utterance``; records without text are skipped.

    Args:
        name: dataset key ("twitter", "imdb", or a key of TEST_FILES).
        path: optional file path overriding ``TEST_FILES[name]`` for the
            file-backed datasets; ignored for "twitter" and "imdb".

    Raises:
        KeyError: ``name`` is not file-backed-known and ``path`` is None.
    """
    if name == "twitter":
        ds = load_dataset("tweet_eval", "sentiment", split="test")
        texts, labels = [], []
        for t, l in zip(ds["text"], ds["label"]):
            if l in (0, 2):
                texts.append(t)
                labels.append(1 if l == 2 else 0)
        return texts, np.array(labels)

    if name == "imdb":
        ds = load_dataset("imdb", split="test")
        # list() so callers can always slice, whatever column type `datasets` returns.
        return list(ds["text"]), np.array(ds["label"])

    if path is None:
        path = TEST_FILES[name]

    if name == "sst5":
        texts, labels, bad = [], [], 0
        pat = re.compile(r"^(?P<text>.+?)\s*[,\t ]+\s*(?P<label>positive|negative)\s*$", re.IGNORECASE)
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                m = pat.match(s)
                if not m:
                    bad += 1
                    continue
                texts.append(m.group("text").strip())
                labels.append(1 if m.group("label").lower() == "positive" else 0)
        if bad: print(f"[sst5] skipped {bad} malformed lines")
        return texts, np.array(labels, dtype=np.int64)

    # Remaining datasets: JSON list of records.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    label_keys = ("label", "valence", "sentiment", "emotion")
    texts, labels = [], []
    for rec in data:
        # First label key that is actually set. A plain `a or b or ...` chain
        # would treat the integer label 0 (negative) as missing and drop it.
        lbl = next((rec[k] for k in label_keys if rec.get(k) not in (None, "")), None)
        if isinstance(lbl, str):
            key = lbl.strip().lower()
            if key not in ("positive", "negative"):
                continue
            y = 1 if key == "positive" else 0
        elif isinstance(lbl, int) and lbl in (0, 1):
            # Scores are thresholded at 0 into {0, 1}, so other ints can't be scored.
            y = int(lbl)
        else:
            continue
        text = rec.get("text") or rec.get("utterance")
        if text:
            texts.append(text)
            labels.append(y)
    return texts, np.array(labels, dtype=np.int64)

@torch.inference_mode()
def encode_all_layers(texts, tokenizer, model, num_layers):
    """Mean-pool hidden states over non-padding tokens for the last ``num_layers`` layers.

    Texts are tokenized in chunks of BATCH_SIZE, padded/truncated to MAX_LEN,
    and run through ``model`` with ``output_hidden_states=True``.

    Args:
        texts: sliceable sequence of strings.
        tokenizer: HF tokenizer used to build the batch.
        model: HF model; its ``config.hidden_size`` sets the output width.
        num_layers: how many trailing entries of ``hidden_states`` to keep.

    Returns:
        A list of ``num_layers`` float32 arrays of shape (len(texts), hidden_size).
        Entry k pools ``hidden_states[len(hidden_states) - num_layers + k]``, so when
        ``hidden_states`` has ``num_layers + 1`` entries the first one is skipped.
    """
    total = len(texts)
    width = model.config.hidden_size
    pooled_by_layer = [np.empty((total, width), dtype=np.float32) for _ in range(num_layers)]

    for start in tqdm(range(0, total, BATCH_SIZE), desc="🧠 Encoding"):
        stop = start + BATCH_SIZE
        enc = tokenizer(texts[start:stop], return_tensors="pt", padding="max_length",
                        truncation=True, max_length=MAX_LEN).to(model.device)
        hidden = model(**enc, output_hidden_states=True).hidden_states
        first = len(hidden) - num_layers

        # Zero out padding positions; clamp so an all-pad row cannot divide by 0.
        keep = enc["attention_mask"].unsqueeze(-1).float()
        counts = keep.sum(1).clamp_min(1e-6)

        for layer_idx in range(num_layers):
            mean = (hidden[first + layer_idx] * keep).sum(1) / counts
            pooled_by_layer[layer_idx][start:stop] = mean.cpu().numpy()

    return pooled_by_layer

# ---------------- 读取“单轴” & usage 符号门控 ----------------
def load_single_axis(dir_path, L):
    """Load layer ``L``'s single axis and return it L2-normalized.

    Reads ``{dir_path}/sentiment_axis_L{L}.npy``. A 1-D array is used as-is;
    for any other rank, the first entry along axis 0 is taken.

    Returns:
        The unit-norm vector, or ``None`` when the file is missing or the
        vector's norm is not strictly positive (zero or NaN).
    """
    axis_path = os.path.join(dir_path, f"sentiment_axis_L{L}.npy")
    if not os.path.exists(axis_path):
        return None

    raw = np.load(axis_path)
    vec = raw[0] if raw.ndim != 1 else raw

    norm = np.linalg.norm(vec)
    if not norm > 0:
        return None
    return vec / (norm + 1e-12)

def _read_id_file(path):
    """Read integer ids from a text file, one per line.

    Surrounding whitespace is stripped. Lines that are not an optionally
    negative base-10 integer (blank lines, headers, "--3", "+5", ...) are
    ignored. A leading UTF-8 BOM is tolerated.

    Returns:
        The ids in file order (duplicates and negatives kept), or ``[]`` if
        ``path`` does not exist.
    """
    if not os.path.exists(path):
        return []
    ids = []
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            token = line.strip()
            # fullmatch rather than `lstrip("-").isdigit()`: the latter accepts
            # "--3" and "²", for which int() then raises ValueError.
            if re.fullmatch(r"-?\d+", token):
                ids.append(int(token))
    return ids

def load_neuron_ids(id_dir, subset, L):
    """Return the sorted, de-duplicated non-negative neuron ids for one layer.

    Reads ``{id_dir}/{subset}/L{L:02d}.txt`` (e.g. ``subset="flip"``) via
    ``_read_id_file``. Negative ids are dropped. Duplicates are collapsed so a
    caller applying a per-neuron transform cannot apply it twice to the same
    neuron (e.g. two sign flips silently cancelling).
    """
    path = os.path.join(id_dir, subset, f"L{L:02d}.txt")
    return sorted({i for i in _read_id_file(path) if i >= 0})

def load_usage_sign_vector(usage_csv_dir, L, usage, tau=1e-3):
    """Return a per-row ±1/0 sign vector for one usage at layer ``L``.

    Reads column ``diff_{usage}`` from
    ``{usage_csv_dir}/layer{L}_usage_neurons.csv``. Values above ``+tau`` map
    to +1, values below ``-tau`` to -1, and everything else (including NaN) to
    0. The -1 test takes precedence, which only matters when ``tau < 0``.

    NOTE(review): the result is indexed by CSV row order, and callers index it
    by neuron id — this assumes row i is hidden unit i; confirm.

    Returns:
        A float32 array, or ``None`` if the file or the column is missing.
    """
    csv_path = os.path.join(usage_csv_dir, f"layer{L}_usage_neurons.csv")
    if not os.path.exists(csv_path):
        return None

    table = pd.read_csv(csv_path)
    if f"diff_{usage}" not in table.columns:
        return None

    values = table[f"diff_{usage}"].to_numpy()
    gated = np.where(values < -tau, -1.0, np.where(values > +tau, 1.0, 0.0))
    return gated.astype(np.float32)

def fuse_multi_usage_sign(usage_csv_dir, L, usages, tau=1e-3, mode="mean_diff"):
    """Fuse several usages' ``diff_{u}`` columns at layer ``L`` into one ±1/0 sign vector.

    Reads ``{usage_csv_dir}/layer{L}_usage_neurons.csv``. Requested usages
    whose column is absent are skipped. Fusion modes:

      * "mean_diff" (also used for any unrecognised mode): threshold the
        per-row mean of the raw diffs at ``±tau``.
      * "mean_sign" / "vote": threshold each diff at ``±tau`` first, then take
        the sign of their sum (the mean and the sum always share a sign).
      * "maxabs": per row, keep the diff with the largest magnitude and
        threshold it at ``±tau``.

    Thresholding maps > +tau to +1, < -tau to -1, otherwise (incl. NaN) 0.

    Returns:
        A float32 array, or ``None`` if the CSV is missing or none of the
        requested columns exist.
    """
    csv_path = os.path.join(usage_csv_dir, f"layer{L}_usage_neurons.csv")
    if not os.path.exists(csv_path):
        return None
    table = pd.read_csv(csv_path)

    def to_sign(x, thr):
        # -1 test checked first so it wins if both hold (possible only for thr < 0).
        return np.where(x < -thr, -1.0, np.where(x > +thr, 1.0, 0.0)).astype(np.float32)

    present = [f"diff_{u}" for u in usages if f"diff_{u}" in table.columns]
    if not present:
        return None

    columns = [table[c].to_numpy() for c in present]
    diff_mat = np.stack(columns, 0)
    sign_mat = np.stack([to_sign(c, tau) for c in columns], 0)

    if mode in ("mean_sign", "vote"):
        return to_sign(sign_mat.sum(0), 0.0)
    if mode == "maxabs":
        winner = np.abs(diff_mat).argmax(axis=0)
        picked = np.take_along_axis(diff_mat, winner[np.newaxis, :], axis=0)[0]
        return to_sign(picked, tau)
    # "mean_diff" and fallback for unknown modes
    return to_sign(diff_mat.mean(0), tau)

def build_axis_with_flip_signs(v_base, flip_ids, sign_vec):
    """Gate the flip-neuron coordinates of an axis by usage sign, then L2-normalize.

    For every id ``i`` in ``flip_ids`` with ``0 <= i < len(v_base)``, the
    coordinate becomes ``v_base[i] * sign_vec[i]`` (sign is +1, -1 or 0). All
    other coordinates keep their ``v_base`` weight. Repeated ids are applied
    once, so duplicates cannot cancel a flip.

    Args:
        v_base: 1-D base axis, or None.
        flip_ids: iterable of integer neuron ids (list or array), or None.
        sign_vec: per-neuron sign array indexable by neuron id, or None.

    Returns:
        ``None`` if ``v_base`` is None or the gated vector has zero norm;
        ``v_base`` itself (not a copy) when there are no flip ids or no sign
        vector; otherwise a new unit-norm array.

    Raises:
        IndexError: an in-range flip id is not a valid index into ``sign_vec``.
    """
    if v_base is None:
        return None
    if flip_ids is None or len(flip_ids) == 0 or sign_vec is None:
        return v_base

    H = v_base.shape[0]
    ids = np.unique(np.asarray(flip_ids, dtype=np.int64))
    ids = ids[(ids >= 0) & (ids < H)]
    signs = np.asarray(sign_vec, dtype=np.float64)

    v = v_base.copy()
    # Read from v_base (not v): a duplicated id must not apply its sign twice,
    # since two -1 factors would silently undo the flip.
    v[ids] = v_base[ids] * signs[ids]

    n = np.linalg.norm(v)
    return v / (n + 1e-12) if n > 0 else None

def stdonly(x):
    """Divide scores by their standard deviation, without mean-centering.

    The divisor is ``std + 1e-6``, a positive scalar for any finite input, so
    the sign of every score is preserved; thresholding at 0 gives the same
    predictions before and after.
    """
    scale = x.std() + 1e-6
    return x / scale

# --------- 生成 16 个 usage 组合（含空） ----------
def all_usage_combos(usages=None):
    """Return every subset of ``usages``, including the empty one, as lists.

    Args:
        usages: ordered usage names; defaults to USAGES_ORDERED (16 subsets
            for its four entries).

    Returns:
        Subsets ordered by size, then in ``itertools.combinations`` order, so
        the empty list comes first and the full set last.
    """
    pool = USAGES_ORDERED if usages is None else list(usages)
    return [
        list(combo)
        for size in range(len(pool) + 1)
        for combo in itertools.combinations(pool, size)
    ]

def combo_name(usages):
    """Return a display name for a usage combo: "∅" if empty, else names joined by "+"."""
    if usages:
        return "+".join(usages)
    return "∅"

# ---------------- 主流程：单轴 + usage 门控 ----------------
def _split_csv_arg(value):
    """Split a comma-separated CLI value into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _resolve_usage_combos(only_usages):
    """Return the usage combos to evaluate, restricted to ``only_usages``.

    Each subset produced by ``all_usage_combos()`` is filtered down to the
    allowed usages (all of USAGES_SET when ``only_usages`` is empty). The
    results are de-duplicated in first-seen order, so the empty combo is
    always present. Names outside USAGES_SET are reported and then ignored.
    """
    allowed = set(_split_csv_arg(only_usages)) or USAGES_SET
    unknown = sorted(allowed - USAGES_SET)
    if unknown:
        print(f"[warn] --only_usages: ignoring unknown usage(s) {unknown}; "
              f"valid: {USAGES_ORDERED}")
    combos, seen = [], set()
    for combo in all_usage_combos():
        kept = [u for u in combo if u in allowed]
        key = tuple(kept)
        if key not in seen:
            seen.add(key)
            combos.append(kept)
    return combos


def _gated_axis(cfg, L, usages, args):
    """Return layer L's evaluation axis for one usage combo, or None to skip the layer.

    Loads the single axis. For a non-empty combo, the flip-neuron coordinates
    are gated by the usage sign vector: one usage is read directly, several
    are fused with ``args.multi_usage_mode``.
    """
    v_base = load_single_axis(cfg["single_axis_dir"], L)
    if v_base is None or not usages:
        return v_base
    flip_ids = load_neuron_ids(cfg["id_dir"], "flip", L)
    if len(usages) == 1:
        sign_vec = load_usage_sign_vector(cfg["usage_csv_dir"], L, usages[0], tau=args.flip_tau)
    else:
        sign_vec = fuse_multi_usage_sign(cfg["usage_csv_dir"], L, usages,
                                         tau=args.flip_tau, mode=args.multi_usage_mode)
    return build_axis_with_flip_signs(v_base, flip_ids, sign_vec)


def main():
    """Evaluate the single axis (optionally usage-gated) per model/dataset/combo.

    For every layer, texts are scored as ``hidden @ axis`` and predicted
    positive when the score is > 0. Per-layer accuracies are written to one CSV
    per (model, dataset, combo); the best layer of each goes into a summary CSV.
    Note that the best layer is selected on the same test split it is reported
    on.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", default="Gemma-7B-IT",
                    help="逗号分隔，可选: " + ",".join(MODELS.keys()))
    ap.add_argument("--datasets", default=",".join(DATASETS))
    ap.add_argument("--flip_tau", type=float, default=0.0,
                    help="usage 符号阈值，小于阈值置 0")
    ap.add_argument("--multi_usage_mode", default="mean_diff",
                    choices=["mean_diff","mean_sign","vote","maxabs"],
                    help="多 usage 的符号融合策略")
    ap.add_argument("--only_usages", default="",
                    help="限制可用 usage 子集（如 genre,tone）；留空=四个全用")
    # NOTE: stdonly divides by a positive scalar, so it never changes the
    # `scores > 0` predictions and hence never changes the reported accuracy.
    ap.add_argument("--use_stdonly", action="store_true",
                    help="对分数做 std-only 归一化（提升不同层对比的稳健性）")
    ap.add_argument("--out_dir", default="output/results_single_axis_usage_gated")
    args = ap.parse_args()

    target_models = _split_csv_arg(args.models)
    target_dsets = _split_csv_arg(args.datasets)
    os.makedirs(args.out_dir, exist_ok=True)

    usage_combos = _resolve_usage_combos(args.only_usages)

    all_rows = []

    for model_name in target_models:
        if model_name not in MODELS:
            print(f"skip unknown model {model_name}"); continue
        cfg = MODELS[model_name]
        n_layers = cfg["layers"]

        print(f"\n🚀 Evaluating {model_name}")
        tok = AutoTokenizer.from_pretrained(cfg["model_id"])
        if tok.pad_token is None: tok.pad_token = tok.eos_token
        model = AutoModel.from_pretrained(cfg["model_id"]).to(DEVICE).eval()

        # Axes depend only on (combo, layer), never on the dataset: build each
        # one lazily, once per model, instead of re-reading the .npy/.txt/.csv
        # files for every dataset. A cached None marks a layer to skip.
        axis_cache = {}

        for dname in target_dsets:
            texts, labels = load_dataset_split(dname)
            if len(texts) == 0:
                print("skip empty", dname); continue
            layers_repr = encode_all_layers(texts, tok, model, n_layers)

            for usages in usage_combos:
                cname = combo_name(usages)
                print(f"\n📊 {model_name} - {dname} | combo={cname} (usage-gated flip)")
                best = {"acc": -1, "L": None}
                perL = []

                for L in range(1, n_layers + 1):
                    key = (tuple(usages), L)
                    if key not in axis_cache:
                        axis_cache[key] = _gated_axis(cfg, L, usages, args)
                    v_use = axis_cache[key]
                    if v_use is None:
                        # No axis file for this layer, or gating zeroed the whole axis.
                        continue

                    H_L = layers_repr[L - 1]                 # [N, H]
                    scores = (H_L @ v_use).ravel()

                    if args.use_stdonly:
                        scores = stdonly(scores)

                    acc = accuracy_score(labels, (scores > 0).astype(int))
                    perL.append({"layer": f"L{L:02d}", "acc": acc})

                    if acc > best["acc"]:
                        best = {"acc": acc, "L": L}

                # Per-layer curve ("+" is replaced by "-" in the file name).
                curve_csv = os.path.join(
                    args.out_dir,
                    f"{model_name}_{dname}_combo_{cname.replace('+','-')}_perlayer.csv"
                )
                pd.DataFrame(perL).to_csv(curve_csv, index=False)

                if best["L"] is None:
                    print(f"❌ No valid layer for {model_name}-{dname}-{cname}")
                    continue

                print(f"✅ best = L{best['L']:02d}, acc={best['acc']:.6f}")
                all_rows.append({
                    "model": model_name,
                    "dataset": dname,
                    "usage_combo": cname,
                    "multi_usage_mode": args.multi_usage_mode if len(usages)>1 else "n/a",
                    "flip_tau": args.flip_tau,
                    "use_stdonly": bool(args.use_stdonly),
                    "best_layer": f"L{best['L']:02d}",
                    "acc": best["acc"]
                })

            del layers_repr
            torch.cuda.empty_cache(); gc.collect()

        del model, tok, axis_cache
        torch.cuda.empty_cache(); gc.collect()
        print(f"🧹 Finished {model_name}")

    if all_rows:
        out_csv = os.path.join(args.out_dir, "ALL_models_singleaxis_usagegated_summary.csv")
        pd.DataFrame(all_rows).to_csv(out_csv, index=False)
        print("\n📄 Saved:", out_csv)
        print(pd.DataFrame(all_rows).to_string(index=False))


if __name__ == "__main__":
    main()
