#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Evaluate ONLY flip axes with usage-gated neurons.

- For each layer, load one trained flip axis:
  experiments/axis_from_new_neurons/{model_lower}_flip/sentiment_axis_L{L}.npy
- Apply usage gating to the flip neurons listed in {ID_DIR}/{model}/flip/L{XX}.txt:
  multiply their axis weights by a usage sign vector (±1/0), then re-normalize.
  Without gating (the ∅ combo), the raw flip axis is used.
- Stable axes are ignored entirely.
- One score per layer; select the best-performing layer for each dataset.

Outputs:
1) {out_dir}/{model}_{dataset}_combo_{combo}_perlayer.csv
   ({combo} = the usages joined by "-", or "∅" when no gating is applied)
2) {out_dir}/ALL_models_flip_usagegated_summary.csv
"""

import os, re, json, argparse, random, gc, itertools
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModel

# ---------------- Basic config ----------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 128      # tokenizer padding/truncation length
BATCH_SIZE = 128   # texts per forward pass
SEED = 42

# Seed every RNG source the script touches.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths
ID_DIR = "outputs/flip_stable_neurons"    # only {model}/flip/L{XX}.txt is read
USAGE_DIR = "outputs/usage_neurons_new"   # {model}/layer{L}_usage_neurons.csv

# Models: display name -> HF model id, layer count and per-model input dirs.
# Each flip-axis directory is the lower-cased display name plus "_flip".
MODELS = {
    name: {
        "model_id": hf_id,
        "layers": n_layers,
        "single_axis_dir": f"experiments/axis_from_new_neurons/{name.lower()}_flip",
        "id_dir": os.path.join(ID_DIR, name),
        "usage_csv_dir": os.path.join(USAGE_DIR, name),
    }
    for name, hf_id, n_layers in (
        ("Gemma-7B-IT", "google/gemma-7b-it", 28),
        ("LLaMA-3-8B", "meta-llama/Meta-Llama-3-8B-Instruct", 32),
        ("Mistral-7B", "mistralai/Mistral-7B-Instruct-v0.2", 32),
    )
}

# Datasets (twitter / imdb come from the HF hub, so they have no local file)
DATASETS = [
    "AnimalsBeingBros", "Confession", "Cringe", "Dialogue",
    "OkCupid", "twitter", "imdb", "sst5",
]
TEST_FILES = {
    "Dialogue":         "data/processed/test/dailydialog_emotion_filtered.json",
    "AnimalsBeingBros": "data/processed/test/AnimalsBeingBros_posneg.json",
    "Confession":       "data/processed/test/confession_posneg.json",
    "Cringe":           "data/processed/test/cringe_posneg.json",
    "OkCupid":          "data/processed/test/OkCupid_posneg.json",
    "sst5":             "data/processed/test/sst5_binary_phrases.txt",
    "twitter":          None,
    "imdb":             None,
}

# Usage dimensions, in the order used to enumerate combos
USAGES_ORDERED = ["genre", "topic", "tone", "contextual"]
USAGES_SET = set(USAGES_ORDERED)

# ---------------- Dataset loading ----------------
def load_dataset_split(name):
    """Load the binary-sentiment test split of dataset ``name``.

    Sources:
      * ``"twitter"`` - HF ``tweet_eval``/``sentiment`` test split; only labels
        0 and 2 are kept and mapped 2 -> 1 (positive), 0 -> 0 (negative).
      * ``"imdb"``    - HF ``imdb`` test split, labels used as-is.
      * ``"sst5"``    - text file ``TEST_FILES["sst5"]`` with one
        ``<text><sep><positive|negative>`` per line (sep: comma/tab/space);
        malformed lines are skipped and counted.
      * anything else - JSON list of records at ``TEST_FILES[name]``.

    For JSON records the label is the first present value among ``label``,
    ``valence``, ``sentiment``, ``emotion`` (None and blank strings count as
    absent; an integer 0 is a real label). String labels must be
    "positive"/"negative" (case-insensitive, surrounding whitespace ignored).
    Integer labels are only accepted as 0 or 1, assumed to mean
    negative/positive (NOTE(review): confirm against each JSON file). Records
    with any other label or without a ``text``/``utterance`` string are skipped.

    Returns:
        (texts, labels): a list of strings and a numpy array of 1 (positive)
        / 0 (negative).
    """
    if name == "twitter":
        ds = load_dataset("tweet_eval", "sentiment", split="test")
        texts, labels = [], []
        for t, l in zip(ds["text"], ds["label"]):
            if l in (0, 2):  # 0=neg, 2=pos; label 1 is dropped
                texts.append(t); labels.append(1 if l == 2 else 0)
        return texts, np.array(labels)
    if name == "imdb":
        ds = load_dataset("imdb", split="test")
        # Materialize as plain lists: depending on the `datasets` version,
        # ds["..."] may return a lazy column object instead of a list.
        return list(ds["text"]), np.array(list(ds["label"]))
    if name == "sst5":
        path = TEST_FILES[name]
        texts, labels, bad = [], [], 0
        pat = re.compile(r"^(?P<text>.+?)\s*[,\t ]+\s*(?P<label>positive|negative)\s*$", re.IGNORECASE)
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                m = pat.match(s)
                if not m:
                    bad += 1
                    continue
                texts.append(m.group("text").strip())
                labels.append(1 if m.group("label").lower() == "positive" else 0)
        if bad:
            print(f"[sst5] skipped {bad} malformed lines")
        return texts, np.array(labels)

    path = TEST_FILES[name]
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def _present(value):
        # None and blank strings count as missing; 0 / False are real labels.
        return value is not None and not (isinstance(value, str) and not value.strip())

    texts, labels = [], []
    for j in data:
        # First key holding a present value. A plain `a or b or ...` chain
        # treats an integer label 0 as missing, which silently dropped every
        # negative example encoded as 0.
        lbl = next(
            (j[k] for k in ("label", "valence", "sentiment", "emotion") if _present(j.get(k))),
            None,
        )
        if isinstance(lbl, str):
            key = lbl.strip().lower()
            if key not in {"positive", "negative"}:
                continue
            y = 1 if key == "positive" else 0
        elif isinstance(lbl, int) and lbl in (0, 1):
            # Binary evaluation only: any other integer code could never match
            # a 0/1 prediction and would just count as an error.
            y = int(lbl)
        else:
            continue
        text = j.get("text") or j.get("utterance")
        if isinstance(text, str) and text:
            texts.append(text); labels.append(y)
    return texts, np.array(labels)

@torch.inference_mode()
def encode_all_layers(texts, tokenizer, model, num_layers):
    """Mean-pool every layer's hidden states for ``texts``.

    Texts are tokenized ``BATCH_SIZE`` at a time, padded/truncated to exactly
    ``MAX_LEN`` tokens, and run through ``model`` with
    ``output_hidden_states=True``. Only the last ``num_layers`` hidden-state
    tensors are used; each is averaged over the positions whose attention
    mask is set (the token count is clamped at 1e-6 to avoid dividing by 0).

    Returns:
        A list of ``num_layers`` float32 arrays of shape
        ``(len(texts), hidden_size)``; element ``k`` holds layer ``k + 1``.
    """
    total = len(texts)
    width = model.config.hidden_size
    pooled_by_layer = [np.empty((total, width), dtype=np.float32)
                       for _ in range(num_layers)]

    for start in tqdm(range(0, total, BATCH_SIZE), desc="Encoding"):
        stop = start + BATCH_SIZE
        encoded = tokenizer(
            texts[start:stop],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
        ).to(model.device)
        hidden_states = model(**encoded, output_hidden_states=True).hidden_states

        # The model may return more hidden states than num_layers (HF models
        # typically prepend the embedding output); skip the extra leading ones.
        first = len(hidden_states) - num_layers
        weights = encoded["attention_mask"].unsqueeze(-1).float()
        token_counts = weights.sum(1).clamp_min(1e-6)

        for k, dest in enumerate(pooled_by_layer):
            layer_mean = (hidden_states[first + k] * weights).sum(1) / token_counts
            dest[start:stop] = layer_mean.cpu().numpy()

    return pooled_by_layer

# ---------------- Axes & usage gating ----------------
def load_flip_axis(dir_path, L):
    """Load layer ``L``'s flip axis as a unit vector.

    Reads ``{dir_path}/sentiment_axis_L{L}.npy``. A 1-D array is used as-is;
    otherwise the first entry along axis 0 is taken.

    Returns:
        The L2-normalised vector, or None if the file is missing or the
        vector's norm is not positive (zero or NaN).
    """
    axis_path = os.path.join(dir_path, f"sentiment_axis_L{L}.npy")
    if not os.path.exists(axis_path):
        return None
    arr = np.load(axis_path)
    vec = arr if arr.ndim == 1 else arr[0]
    norm = np.linalg.norm(vec)
    if not norm > 0:
        return None
    return vec / (norm + 1e-12)

def _read_id_file(path):
    """Read integer IDs, one per line, from ``path``.

    Blank lines and lines that are not a plain decimal integer (optionally
    with a single leading "-") are skipped. A missing file yields ``[]``.
    Negative IDs are returned as-is.
    """
    if not os.path.exists(path):
        return []
    out = []
    # utf-8-sig strips a leading BOM, which would otherwise make the first ID
    # fail validation and be dropped silently.
    with open(path, "r", encoding="utf-8-sig", errors="replace") as f:
        for x in f:
            x = x.strip()
            # Strict ASCII check: the previous `lstrip("-").isdigit()` test let
            # inputs like "--5" or "²" through, and int() then raised ValueError.
            if re.fullmatch(r"-?[0-9]+", x):
                out.append(int(x))
    return out

def load_flip_ids(id_dir, L):
    """Return the sorted, non-negative flip-neuron IDs listed for layer ``L``.

    IDs are read with ``_read_id_file`` from ``{id_dir}/flip/L{L:02d}.txt``,
    so a missing file gives ``[]``. Duplicates are kept.
    """
    ids_path = os.path.join(id_dir, "flip", f"L{L:02d}.txt")
    kept = [idx for idx in _read_id_file(ids_path) if not idx < 0]
    kept.sort()
    return kept

def load_usage_sign_vector(usage_csv_dir, L, usage, tau=1e-3):
    """Threshold one usage's per-neuron diffs into a float32 ±1/0 vector.

    Reads column ``diff_<usage>`` of ``{usage_csv_dir}/layer{L}_usage_neurons.csv``.
    Entries above ``+tau`` become +1, then entries below ``-tau`` become -1
    (the -1 pass runs second, so it wins if both hold), everything else is 0.
    Entry ``i`` corresponds to CSV row ``i``; NOTE(review): callers index it
    by neuron id, which assumes the rows are ordered by neuron - confirm.

    Returns:
        The sign vector, or None if the CSV or the column is missing.
    """
    csv_path = os.path.join(usage_csv_dir, f"layer{L}_usage_neurons.csv")
    if not os.path.exists(csv_path):
        return None
    table = pd.read_csv(csv_path)
    column = f"diff_{usage}"
    if column not in table.columns:
        return None
    values = table[column].to_numpy()
    out = np.zeros_like(values, dtype=np.float32)
    positive = values > +tau
    negative = values < -tau
    out[positive] = +1.0
    out[negative] = -1.0
    return out

def fuse_multi_usage_sign(usage_csv_dir, L, usages, tau=1e-3, mode="mean_diff"):
    """Combine several usages' neuron diffs into a single float32 ±1/0 vector.

    Reads ``{usage_csv_dir}/layer{L}_usage_neurons.csv`` and uses the
    ``diff_<usage>`` column of each requested usage that exists (missing
    columns are ignored). ``mode`` selects the fusion rule:

    * ``"mean_diff"`` - average the raw diffs, then threshold at ±tau.
    * ``"mean_sign"`` - threshold each usage at ±tau, average those signs and
      keep the sign of the average.
    * ``"vote"``      - threshold each usage at ±tau, sum the signs and keep
      the sign of the sum.
    * ``"maxabs"``    - per neuron, take the diff with the largest magnitude
      across usages (the first usage wins ties), then threshold at ±tau.
    * any other value falls back to ``"mean_diff"``.

    Thresholding sets values above the upper bound to +1, then values below
    the lower bound to -1, and leaves the rest at 0.

    Returns:
        The fused sign vector (one entry per CSV row), or None if the CSV is
        missing or none of the requested columns is present.
    """
    csv_path = os.path.join(usage_csv_dir, f"layer{L}_usage_neurons.csv")
    if not os.path.exists(csv_path):
        return None
    table = pd.read_csv(csv_path)

    def threshold(values, upper, lower):
        out = np.zeros_like(values, dtype=np.float32)
        out[values > upper] = +1.0
        out[values < lower] = -1.0
        return out

    columns = [table[f"diff_{u}"].to_numpy()
               for u in usages if f"diff_{u}" in table.columns]
    if len(columns) == 0:
        return None

    if mode in ("mean_sign", "vote"):
        per_usage = np.stack([threshold(col, +tau, -tau) for col in columns], 0)
        pooled = per_usage.mean(0) if mode == "mean_sign" else per_usage.sum(0)
        return threshold(pooled, 0.0, 0.0)

    raw = np.stack(columns, 0)
    if mode == "maxabs":
        winner = np.argmax(np.abs(raw), 0)
        picked = raw[winner, np.arange(raw.shape[1])]
        return threshold(picked, +tau, -tau)

    # "mean_diff" and every unrecognised mode
    return threshold(raw.mean(0), +tau, -tau)

def apply_usage_gate_to_flip_axis(v_flip, flip_ids, sign_vec):
    """Gate the flip-neuron coordinates of a flip axis by a usage sign vector.

    Each coordinate ``i`` listed in ``flip_ids`` is multiplied once by
    ``sign_vec[i]`` (+1 keeps, -1 flips, 0 removes); other coordinates are not
    multiplied. The result is then re-normalised to unit length, so every
    coordinate - gated or not - may be rescaled.

    IDs that are negative or index past the end of either ``v_flip`` or
    ``sign_vec`` are ignored, so a sign vector shorter than the axis leaves
    the uncovered coordinates ungated instead of raising IndexError.
    Duplicate IDs are gated only once; a per-occurrence loop would apply a -1
    sign twice and silently undo the flip.

    Args:
        v_flip: flip axis (1-D array) or None.
        flip_ids: sequence of neuron indices to gate (list or array), or None.
        sign_vec: per-neuron ±1/0 values (1-D array), or None.

    Returns:
        None if ``v_flip`` is None or gating zeroes the whole axis;
        ``v_flip`` itself (unchanged) when there are no IDs or no sign vector;
        otherwise the gated, unit-norm axis as a new array.
    """
    if v_flip is None:
        return None
    if flip_ids is None or len(flip_ids) == 0 or sign_vec is None:
        return v_flip

    signs = np.asarray(sign_vec)
    ids = np.unique(np.asarray(flip_ids, dtype=np.int64))
    limit = min(v_flip.shape[0], signs.shape[0])
    ids = ids[(ids >= 0) & (ids < limit)]

    v = v_flip.copy()
    v[ids] = v[ids] * signs[ids]
    n = np.linalg.norm(v)
    return v / (n + 1e-12) if n > 0 else None

def stdonly(x):
    """Scale ``x`` by its standard deviation, without centering.

    An epsilon of 1e-6 is added to the divisor. Because the divisor is
    non-negative, the sign of every finite score is unchanged, so a
    downstream ``scores > 0`` decision gives the same predictions with or
    without this scaling.
    """
    scale = x.std() + 1e-6
    return x / scale

# --------- All 16 usage combos (including empty) ----------
def all_usage_combos(usages=None):
    """Return every subset of ``usages`` as a list, including the empty one.

    Subsets are ordered by size, then in ``itertools.combinations`` order, and
    keep the order of ``usages``. With the default ``None``,
    ``USAGES_ORDERED`` is used (4 usages -> 16 combos).

    Args:
        usages: optional iterable of usage names to combine.
    """
    pool = list(USAGES_ORDERED if usages is None else usages)
    return [list(c)
            for k in range(len(pool) + 1)
            for c in itertools.combinations(pool, k)]

def combo_name(usages):
    """Label a usage combo: its usages joined with "+", or "∅" when empty."""
    if usages:
        return "+".join(usages)
    return "∅"

# ---------------- Main ----------------
def _parse_args():
    """Define and parse the command-line options."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", default="Gemma-7B-IT",
                    help="Comma-separated, available: " + ",".join(MODELS.keys()))
    ap.add_argument("--datasets", default=",".join(DATASETS))
    ap.add_argument("--flip_tau", type=float, default=0.0,
                    help="Threshold for usage gating; values within ±tau become 0")
    ap.add_argument("--multi_usage_mode", default="mean_diff",
                    choices=["mean_diff","mean_sign","vote","maxabs"],
                    help="Strategy for fusing multiple usage signs")
    ap.add_argument("--only_usages", default="",
                    help="Restrict to a subset of usages (e.g., genre,tone)")
    ap.add_argument("--use_stdonly", action="store_true",
                    help="Apply std-only normalization to scores for comparability")
    # NOTE(review): the input paths use "outputs/" while this default uses
    # "output/"; confirm which one is intended.
    ap.add_argument("--out_dir", default="output/results_flip_axis_usage_gated")
    return ap.parse_args()


def _usage_combos(only_usages):
    """Return the distinct usage combos allowed by ``--only_usages``.

    Every combo from ``all_usage_combos()`` is reduced to the allowed usages
    (all usages when ``only_usages`` is empty); duplicates created by that
    reduction are dropped, keeping first-seen order.
    """
    allowed = {u.strip() for u in only_usages.split(",") if u.strip()} or USAGES_SET
    combos, seen = [], set()
    for combo in all_usage_combos():
        kept = [u for u in combo if u in allowed]
        key = tuple(kept)
        if key not in seen:
            seen.add(key)
            combos.append(kept)
    return combos


def _gated_layer_axes(cfg, usages, flip_axes, flip_tau, multi_usage_mode):
    """Build the per-layer axes evaluated for one usage combo.

    Args:
        cfg: one ``MODELS`` entry.
        usages: list of usage names (empty = no gating).
        flip_axes: mapping layer -> unit flip axis or None.

    Returns:
        (axes, ungated): ``axes`` is a list of ``(L, axis)`` for every layer
        with a usable axis; ``ungated`` lists layers where gating was requested
        but no flip IDs or no sign vector were found, so the raw flip axis was
        used in their place.
    """
    axes, ungated = [], []
    for L in range(1, cfg["layers"] + 1):
        v_flip = flip_axes.get(L)
        if v_flip is None:
            continue
        if not usages:
            axes.append((L, v_flip))
            continue
        flip_ids = load_flip_ids(cfg["id_dir"], L)
        if len(usages) == 1:
            sign_vec = load_usage_sign_vector(cfg["usage_csv_dir"], L, usages[0], tau=flip_tau)
        else:
            sign_vec = fuse_multi_usage_sign(
                cfg["usage_csv_dir"], L, usages,
                tau=flip_tau, mode=multi_usage_mode
            )
        if not flip_ids or sign_vec is None:
            ungated.append(L)
        v_use = apply_usage_gate_to_flip_axis(v_flip, flip_ids, sign_vec)
        if v_use is not None:
            axes.append((L, v_use))
    return axes, ungated


def _score_layers(layers_repr, labels, layer_axes, use_stdonly):
    """Score each ``(L, axis)`` pair; a text is predicted positive when its projection is > 0.

    Returns:
        (per_layer_rows, best): rows ``{"layer", "acc"}`` in layer order and
        ``best = {"acc", "L"}`` for the first layer reaching the highest
        accuracy (``L`` is None when ``layer_axes`` is empty).
    """
    best = {"acc": -1, "L": None}
    per_layer = []
    for L, axis in layer_axes:
        scores = (layers_repr[L - 1] @ axis).ravel()
        if use_stdonly:
            scores = stdonly(scores)
        acc = accuracy_score(labels, (scores > 0).astype(int))
        per_layer.append({"layer": f"L{L:02d}", "acc": acc})
        if acc > best["acc"]:
            best = {"acc": acc, "L": L}
    return per_layer, best


def main():
    """Evaluate usage-gated flip axes for every requested model × dataset × usage combo.

    For each combo a per-layer accuracy CSV is written to ``--out_dir``; the
    best layer of each (model, dataset, combo) is collected into
    ``ALL_models_flip_usagegated_summary.csv`` in the same directory.
    """
    args = _parse_args()

    target_models = [m.strip() for m in args.models.split(",") if m.strip()]
    target_dsets  = [d.strip() for d in args.datasets.split(",") if d.strip()]
    os.makedirs(args.out_dir, exist_ok=True)
    usage_combos = _usage_combos(args.only_usages)

    all_rows = []

    for model_name in target_models:
        if model_name not in MODELS:
            print(f"Skip unknown model {model_name}"); continue
        cfg = MODELS[model_name]

        print(f"\n🚀 Evaluating {model_name}")
        tok = AutoTokenizer.from_pretrained(cfg["model_id"])
        if tok.pad_token is None: tok.pad_token = tok.eos_token
        model = AutoModel.from_pretrained(cfg["model_id"]).to(DEVICE).eval()

        # Gated axes depend on the model, the combo and the CLI options but not
        # on the dataset, so build them once per model rather than re-reading
        # the axis / ID / usage-CSV files for every dataset.
        flip_axes = {L: load_flip_axis(cfg["single_axis_dir"], L)
                     for L in range(1, cfg["layers"] + 1)}
        combo_axes = []
        for usages in usage_combos:
            layer_axes, ungated = _gated_layer_axes(
                cfg, usages, flip_axes, args.flip_tau, args.multi_usage_mode)
            if ungated:
                # Make the fallback visible: these layers are reported under a
                # gated combo name but actually use the raw flip axis.
                print(f"⚠️ {model_name} combo={combo_name(usages)}: no flip IDs or usage signs "
                      f"for layers {ungated}; using the ungated flip axis there")
            combo_axes.append((usages, layer_axes))

        for dname in target_dsets:
            texts, labels = load_dataset_split(dname)
            if len(texts) == 0:
                print("Skip empty dataset", dname); continue
            layers_repr = encode_all_layers(texts, tok, model, cfg["layers"])

            for usages, layer_axes in combo_axes:
                name = combo_name(usages)
                print(f"\n📊 {model_name} - {dname} | combo={name} (flip-only, usage-gated)")
                perL, best = _score_layers(layers_repr, labels, layer_axes, args.use_stdonly)

                curve_csv = os.path.join(
                    args.out_dir,
                    f"{model_name}_{dname}_combo_{name.replace('+','-')}_perlayer.csv"
                )
                pd.DataFrame(perL).to_csv(curve_csv, index=False)

                if best["L"] is None:
                    print(f"❌ No valid layer for {model_name}-{dname}-{name}")
                    continue

                print(f"✅ best = L{best['L']:02d}, acc={best['acc']:.6f}")
                all_rows.append({
                    "model": model_name,
                    "dataset": dname,
                    "usage_combo": name,
                    "multi_usage_mode": args.multi_usage_mode if len(usages)>1 else "n/a",
                    "flip_tau": args.flip_tau,
                    "use_stdonly": bool(args.use_stdonly),
                    "best_layer": f"L{best['L']:02d}",
                    "acc": best["acc"]
                })

            del layers_repr
            torch.cuda.empty_cache(); gc.collect()

        del model, tok
        torch.cuda.empty_cache(); gc.collect()
        print(f"🧹 Finished {model_name}")

    if all_rows:
        out_csv = os.path.join(args.out_dir, "ALL_models_flip_usagegated_summary.csv")
        pd.DataFrame(all_rows).to_csv(out_csv, index=False)
        print("\n📄 Saved:", out_csv)
        print(pd.DataFrame(all_rows).to_string(index=False))

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
