#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Evaluate usage-gated flip and stable components within the same layer.

- For each layer, load the main sentiment axis from experiments/exp_main_axis/{model}_main/sentiment_axis_L{L}.npy
- Split into stable and flip sub-axes using neuron ID files.
- Apply usage-gating (±1/0) on flip neurons based on usage-specific sign vectors.
- Compute scores: s_stable = H @ v_stable, s_flip = H @ v_flip_signed.
- Combine scores with alpha grid: s = α * s_flip + (1 - α) * s_stable.
- Search for the best alpha and best layer for each dataset/usage-combo.
- Supports multiple usage combos (16 total, including empty set).

Outputs:
1) {out_dir}/{model}_{dataset}_combo_{usage_combo}_perlayer_alpha.csv
2) {out_dir}/ALL_models_16combos_singlelayer_summary.csv
"""

import os, re, json, argparse, random, gc, itertools
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModel

# ---------------- Basic config ----------------
# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Tokenizer length (truncation and padding) used when encoding texts.
MAX_LEN = 128
# Number of texts pushed through the model per forward pass.
BATCH_SIZE = 128

# One shared seed for the Python, NumPy and torch random number generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths
# Root directories for the three kinds of precomputed inputs this script reads:
# main sentiment axes, flip/stable neuron-id lists, and per-layer usage CSVs.
MAIN_AXIS_DIR = "experiments/exp_main_axis"
ID_DIR = "outputs/flip_stable_neurons"
USAGE_DIR = "outputs/usage_neurons_new"

# Models
# Registry keyed by the display name accepted by --models. Each entry carries
# the Hugging Face model id, the per-model input directories, and the number
# of layers to evaluate.
MODELS = {
    "Gemma-7B-IT": dict(
        model_id="google/gemma-7b-it",
        axis_dir=os.path.join(MAIN_AXIS_DIR, "gemma-7b-it_main"),
        id_dir=os.path.join(ID_DIR, "Gemma-7B-IT"),
        usage_csv_dir=os.path.join(USAGE_DIR, "Gemma-7B-IT"),
        layers=28,
    ),
    # Add other models if needed (LLaMA, Mistral, etc.)
}

# Datasets
# Default evaluation order; --datasets defaults to this list joined by commas.
DATASETS = [
    "AnimalsBeingBros",
    "Confession",
    "Cringe",
    "Dialogue",
    "OkCupid",
    "twitter",
    "imdb",
    "sst5",
]

# Local test file per dataset. A value of None marks a dataset that is not
# read from a local file (twitter and imdb are fetched with load_dataset).
TEST_FILES = {
    "Dialogue": "data/processed/test/dailydialog_emotion_filtered.json",
    "AnimalsBeingBros": "data/processed/test/AnimalsBeingBros_posneg.json",
    "Confession": "data/processed/test/confession_posneg.json",
    "Cringe": "data/processed/test/cringe_posneg.json",
    "OkCupid": "data/processed/test/OkCupid_posneg.json",
    "sst5": "data/processed/test/sst5_binary_phrases.txt",
    "twitter": None,
    "imdb": None,
}

# Usage categories
# The ordered list fixes the enumeration order of usage combos; the set is
# used for membership / subset checks.
USAGES_ORDERED = ["genre", "topic", "tone", "contextual"]
USAGES_SET = set(USAGES_ORDERED)

# ---------------- Dataset loading ----------------
# Record fields probed, in priority order, for a label in the JSON test files.
_LABEL_KEYS = ("label", "valence", "sentiment", "emotion")


def _first_present_label(record):
    """Return the first usable label value found in *record*, or None.

    An integer ``0`` is a real label and must be kept; a plain truthiness
    test (``a or b or ...``) would discard it and drop the whole record.
    Other falsy values (None, "", ...) still count as "missing" so the next
    key is tried.
    """
    for key in _LABEL_KEYS:
        value = record.get(key)
        if value:
            return value
        if isinstance(value, int) and not isinstance(value, bool):
            return value  # integer 0
    return None


def load_dataset_split(name, test_files=None):
    """Load the test texts and binary labels for dataset *name*.

    Args:
        name: Dataset key. "twitter" and "imdb" are fetched with
            ``load_dataset``; "sst5" is parsed from a text file with one
            ``<text><sep><positive|negative>`` entry per line; any other name
            is read as a JSON list of records.
        test_files: Optional mapping from dataset name to file path. When
            omitted, the module-level ``TEST_FILES`` mapping is used.

    Returns:
        ``(texts, labels)`` where ``labels`` is a NumPy array aligned with
        ``texts``.

    Raises:
        KeyError: if a file-backed dataset name is absent from the path map.
    """
    if name == "twitter":
        ds = load_dataset("tweet_eval", "sentiment", split="test")
        texts, labels = [], []
        for text, label in zip(ds["text"], ds["label"]):
            # Only labels 0 and 2 are kept; 2 maps to 1 and 0 stays 0.
            if label in (0, 2):
                texts.append(text)
                labels.append(1 if label == 2 else 0)
        return texts, np.array(labels)

    if name == "imdb":
        ds = load_dataset("imdb", split="test")
        return ds["text"], np.array(ds["label"])

    files = TEST_FILES if test_files is None else test_files
    path = files[name]

    if name == "sst5":
        texts, labels, bad = [], [], 0
        pat = re.compile(
            r"^(?P<text>.+?)\s*[,\t ]+\s*(?P<label>positive|negative)\s*$",
            re.IGNORECASE,
        )
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                m = pat.match(s)
                if not m:
                    bad += 1
                    continue
                texts.append(m.group("text").strip())
                labels.append(1 if m.group("label").lower() == "positive" else 0)
        if bad:
            print(f"[sst5] skipped {bad} malformed lines")
        return texts, np.array(labels)

    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    texts, labels = [], []
    for record in data:
        lbl = _first_present_label(record)
        if isinstance(lbl, str):
            # String labels other than positive/negative are skipped.
            if lbl.lower() not in {"positive", "negative"}:
                continue
            y = 1 if lbl.lower() == "positive" else 0
        elif isinstance(lbl, int):
            y = lbl
        else:
            continue
        text = record.get("text") or record.get("utterance")
        if text:
            texts.append(text)
            labels.append(y)
    return texts, np.array(labels)

@torch.inference_mode()
def encode_all_layers(texts, tokenizer, model, num_layers):
    """Mean-pool the hidden states of the last *num_layers* layers for every text.

    Texts are tokenized in chunks of BATCH_SIZE, padded/truncated to MAX_LEN,
    and run through *model* with hidden states enabled. For each selected
    layer the token vectors are averaged over the attention mask.

    Returns:
        A list of ``num_layers`` float32 arrays, each of shape
        ``(len(texts), model.config.hidden_size)``; entry ``k`` holds layer
        ``k + 1`` counted within the selected layers.
    """
    total = len(texts)
    width = model.config.hidden_size
    pooled_by_layer = [
        np.empty((total, width), dtype=np.float32) for _ in range(num_layers)
    ]

    for start in tqdm(range(0, total, BATCH_SIZE), desc="Encoding"):
        stop = start + BATCH_SIZE
        encoded = tokenizer(
            texts[start:stop],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
        ).to(model.device)
        hidden_states = model(**encoded, output_hidden_states=True).hidden_states

        # Skip the leading entries so that exactly the last `num_layers`
        # hidden-state tensors are used.
        first = len(hidden_states) - num_layers

        # Attention mask as a float weight per token; the clamp avoids a
        # division by zero for a row with no attended tokens.
        weights = encoded["attention_mask"].unsqueeze(-1).float()
        token_counts = weights.sum(1).clamp_min(1e-6)

        for slot in range(num_layers):
            summed = (hidden_states[first + slot] * weights).sum(1)
            mean_vec = summed / token_counts
            pooled_by_layer[slot][start:stop] = mean_vec.cpu().numpy()

    return pooled_by_layer

# ---------------- Axis and usage utilities ----------------
def load_axis(axis_dir, L):
    """Load the unit-normalized sentiment axis stored for layer *L*.

    Reads ``sentiment_axis_L{L}.npy`` from *axis_dir*. If the stored array is
    not one-dimensional, its first entry along axis 0 is used.

    Returns:
        The axis divided by its norm, or None when the file does not exist or
        the norm is zero.
    """
    path = os.path.join(axis_dir, f"sentiment_axis_L{L}.npy")
    if not os.path.exists(path):
        return None

    axis = np.load(path)
    if axis.ndim != 1:
        axis = axis[0]

    norm = np.linalg.norm(axis)
    if norm > 0:
        return axis / norm
    return None

def _read_id_file(path):
    if not os.path.exists(path): return []
    ids=[]
    with open(path,"r") as f:
        for x in f:
            x=x.strip()
            if x and x.lstrip("-").isdigit(): ids.append(int(x))
    return ids

def load_neuron_ids(id_dir, subset, L):
    """Return the sorted non-negative neuron ids listed for *subset* at layer *L*.

    The ids are read from ``{id_dir}/{subset}/L{L:02d}.txt``; negative entries
    are discarded.
    """
    id_file = os.path.join(id_dir, subset, f"L{L:02d}.txt")
    non_negative = filter(lambda idx: idx >= 0, _read_id_file(id_file))
    return sorted(non_negative)

def build_axis_masked(v, ids, H):
    """Restrict axis *v* to the coordinates in *ids* and rescale it.

    Args:
        v: Axis vector of length *H*.
        ids: Neuron indices to keep; indices outside ``[0, H)`` are ignored.
        H: Length of the mask to build.

    Returns:
        The masked axis divided by ``norm + 1e-6``, or None when *ids* is
        empty or the masked axis has zero norm.
    """
    if not ids:
        return None

    keep = np.zeros(H, dtype=np.float32)
    in_range = [j for j in ids if 0 <= j < H]
    keep[np.asarray(in_range, dtype=np.intp)] = 1.0

    masked = v * keep
    length = np.linalg.norm(masked)
    if length > 0:
        return masked / (length + 1e-6)
    return None

def load_usage_sign_vector(usage_csv_dir, L, usage, tau=1e-3):
    """Load the ternary sign vector of one usage category for layer *L*.

    Reads column ``diff_{usage}`` from
    ``{usage_csv_dir}/layer{L}_usage_neurons.csv`` and thresholds it: values
    above ``+tau`` become +1, values below ``-tau`` become -1, the rest 0.

    NOTE(review): entries are taken in CSV row order; this presumably matches
    neuron index order, which the caller relies on -- confirm against the
    code that writes these CSVs.

    Returns:
        ``(sign, diff)`` -- the float32 sign vector and the raw column values.
        When the CSV or the column is missing, ``(None, None)`` is returned so
        that callers unpacking two values can test ``sign is None`` instead of
        failing on a bare ``None``.
    """
    csv_path = os.path.join(usage_csv_dir, f"layer{L}_usage_neurons.csv")
    if not os.path.exists(csv_path):
        return None, None

    df = pd.read_csv(csv_path)
    col = f"diff_{usage}"
    if col not in df.columns:
        return None, None

    diff = df[col].to_numpy()
    sign = np.zeros_like(diff, dtype=np.float32)
    sign[diff > +tau] = +1.0
    sign[diff < -tau] = -1.0
    return sign, diff

def fuse_multi_usage_sign(usage_csv_dir, L, usages, tau=1e-3, mode="mean_diff"):
    """Fuse the usage columns of layer *L* into a single ternary sign vector.

    For every usage in *usages* that has a ``diff_{usage}`` column in
    ``{usage_csv_dir}/layer{L}_usage_neurons.csv``, the raw column and its
    +-tau thresholded sign are collected, then combined according to *mode*:

    - ``"mean_diff"``: threshold the mean of the raw columns at +-tau.
    - ``"mean_sign"``: sign of the mean of the per-usage signs.
    - ``"vote"``: sign of the sum of the per-usage signs.
    - ``"maxabs"``: per entry, take the raw value with the largest magnitude
      and threshold it at +-tau.

    Any other *mode* behaves like ``"mean_diff"``.

    Returns:
        A float32 vector with entries in {-1, 0, +1}, or None when the CSV is
        missing or none of the requested columns exist.
    """

    def ternary(values, threshold):
        # +1 above threshold, -1 below -threshold, 0 in between.
        out = np.zeros_like(values, dtype=np.float32)
        out[values > threshold] = +1.0
        out[values < -threshold] = -1.0
        return out

    csv_path = os.path.join(usage_csv_dir, f"layer{L}_usage_neurons.csv")
    if not os.path.exists(csv_path):
        return None
    table = pd.read_csv(csv_path)

    raw_columns = [
        table[f"diff_{u}"].to_numpy() for u in usages if f"diff_{u}" in table.columns
    ]
    if not raw_columns:
        return None
    per_usage_signs = [ternary(column, tau) for column in raw_columns]

    diff_stack = np.stack(raw_columns, 0)
    sign_stack = np.stack(per_usage_signs, 0)

    if mode == "mean_sign":
        return ternary(sign_stack.mean(0), 0.0)
    if mode == "vote":
        return ternary(sign_stack.sum(0), 0)
    if mode == "maxabs":
        winner = np.argmax(np.abs(diff_stack), 0)
        strongest = np.take_along_axis(diff_stack, winner[None, :], axis=0)[0]
        return ternary(strongest, tau)
    # "mean_diff" and any unrecognized mode.
    return ternary(diff_stack.mean(0), tau)

def build_axis_signed(v, flip_ids, sign_vec):
    """Build the usage-gated flip axis from *v*.

    Each coordinate listed in *flip_ids* is multiplied by its entry in
    *sign_vec*; every other coordinate is zeroed. Indices outside the range
    of *sign_vec* are ignored.

    Returns:
        The gated axis divided by ``norm + 1e-6``, or None when *sign_vec* is
        None, *flip_ids* is empty, or the gated axis has zero norm.
    """
    if sign_vec is None or not flip_ids:
        return None

    gate = np.zeros_like(sign_vec, dtype=np.float32)
    size = len(gate)
    in_range = np.asarray([j for j in flip_ids if 0 <= j < size], dtype=np.intp)
    gate[in_range] = np.asarray(sign_vec)[in_range]

    gated = v * gate
    length = np.linalg.norm(gated)
    if length > 0:
        return gated / (length + 1e-6)
    return None

def stdonly(x):
    """Scale *x* by its standard deviation without centering it.

    A small epsilon is added to the standard deviation so a constant input
    does not divide by zero.
    """
    scale = x.std() + 1e-6
    return x / scale

# --------- All 16 usage combos (including empty set) ----------
def all_usage_combos():
    """Enumerate every subset of USAGES_ORDERED as a list, smallest first.

    Subsets are grouped by size (starting with the empty list) and, within a
    size, follow the order produced by ``itertools.combinations``.
    """
    sizes = range(len(USAGES_ORDERED) + 1)
    return [
        list(subset)
        for size in sizes
        for subset in itertools.combinations(USAGES_ORDERED, size)
    ]

def combo_name(usages):
    """Return a display name for a usage combo: "+"-joined, or "∅" when empty."""
    if usages:
        return "+".join(usages)
    return "∅"

# ---------------- Main ----------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--models", default="Gemma-7B-IT")
    ap.add_argument("--datasets", default=",".join(DATASETS))
    ap.add_argument("--alpha_grid", default="0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8")
    ap.add_argument("--flip_sign", default="usage", choices=["none","usage"],
                    help="Whether to apply usage gating (±1/0) to flip neurons")
    ap.add_argument("--flip_tau", type=float, default=0,
                    help="Threshold for usage sign; values within ±tau become 0")
    ap.add_argument("--multi_usage_mode", default="mean_diff",
                    choices=["mean_diff","mean_sign","vote","maxabs"],
                    help="How to fuse multiple usage signs when a combo has >1 usage")
    ap.add_argument("--only_usages", default="",
                    help="Restrict to a subset of usages (comma-separated, e.g., genre,tone)")
    ap.add_argument("--use_stdonly", action="store_true",
                    help="Apply std-only normalization to flip/stable scores")
    ap.add_argument("--out_dir", default="output/results_all_singlelayer_16combos")
    args = ap.parse_args()

    alpha_grid = [float(x) for x in args.alpha_grid.split(",") if x.strip()]
    target_models = [m.strip() for m in args.models.split(",") if m.strip()]
    target_dsets  = [d.strip() for d in args.datasets.split(",") if d.strip()]
    os.makedirs(args.out_dir, exist_ok=True)

    allowed = set([u.strip() for u in args.only_usages.split(",") if u.strip()]) or USAGES_SET
    base_combos = [ [u for u in c if u in allowed] for c in all_usage_combos() ]
    uniq, seen = [], set()
    for c in base_combos:
        key = tuple(c)
        if key not in seen:
            seen.add(key); uniq.append(c)
    usage_combos = uniq

    all_rows = []

    for model_name in target_models:
        if model_name not in MODELS:
            print(f"Skip unknown model {model_name}"); continue
        cfg = MODELS[model_name]

        print(f"\n🚀 Evaluating {model_name}")
        tok = AutoTokenizer.from_pretrained(cfg["model_id"])
        if tok.pad_token is None: tok.pad_token = tok.eos_token
        model = AutoModel.from_pretrained(cfg["model_id"]).to(DEVICE).eval()
        H = model.config.hidden_size

        for dname in target_dsets:
            texts, labels = load_dataset_split(dname)
            if len(texts)==0:
                print("Skip empty dataset", dname); continue

            layers_repr = encode_all_layers(texts, tok, model, cfg["layers"])

            for usages in usage_combos:
                if not set(usages).issubset(USAGES_SET):
                    continue

                print(f"\n📊 {model_name} - {dname} | combo={combo_name(usages)} | flip_sign={args.flip_sign} | multi={args.multi_usage_mode}")
                best_combo = {"acc": -1, "L": None, "alpha": None}
                perL_rows = []

                for L in range(1, cfg["layers"]+1):
                    v_main = load_axis(cfg["axis_dir"], L)
                    if v_main is None: continue

                    stable_ids = load_neuron_ids(cfg["id_dir"], "stable", L)
                    v_stable = build_axis_masked(v_main, stable_ids, H)
                    if v_stable is None: continue

                    H_L = layers_repr[L-1]

                    if args.flip_sign == "none" or len(usages)==0:
                        s_flip = np.zeros(len(H_L), dtype=np.float32)
                    else:
                        flip_ids = load_neuron_ids(cfg["id_dir"], "flip", L)
                        if not flip_ids:
                            continue
                        if len(usages) == 1:
                            sign_vec, _ = load_usage_sign_vector(cfg["usage_csv_dir"], L, usages[0], tau=args.flip_tau)
                        else:
                            sign_vec = fuse_multi_usage_sign(cfg["usage_csv_dir"], L, usages, tau=args.flip_tau, mode=args.multi_usage_mode)
                        if sign_vec is None:
                            continue
                        v_flip_signed = build_axis_signed(v_main, flip_ids, sign_vec)
                        if v_flip_signed is None:
                            continue
                        s_flip = (H_L @ v_flip_signed).ravel()

                    s_stable = (H_L @ v_stable).ravel()

                    s_flip_use, s_stable_use = s_flip.copy(), s_stable.copy()
                    if args.use_stdonly:
                        s_flip_use   = stdonly(s_flip_use)
                        s_stable_use = stdonly(s_stable_use)

                    acc_stable = accuracy_score(labels, (s_stable_use > 0).astype(int))
                    acc_flip   = accuracy_score(labels, (s_flip_use   > 0).astype(int))
                    perL_rows.append({"layer": f"L{L:02d}", "alpha": 0.0, "acc": acc_stable, "mode": "stable"})
                    perL_rows.append({"layer": f"L{L:02d}", "alpha": 1.0, "acc": acc_flip,   "mode": "flip"})

                    for a in alpha_grid:
                        s = a*s_flip_use + (1.0-a)*s_stable_use
                        acc = accuracy_score(labels, (s > 0).astype(int))
                        perL_rows.append({"layer": f"L{L:02d}", "alpha": a, "acc": acc, "mode": "combo"})
                        if acc > best_combo["acc"]:
                            best_combo = {"acc": acc, "L": L, "alpha": a}

                if best_combo["L"] is None:
                    print(f"❌ No valid layer for {model_name}-{dname}-{combo_name(usages)}")
                    continue

                print(f"✅ {model_name}-{dname}-{combo_name(usages)}: best_same_layer = L{best_combo['L']:02d}, α={best_combo['alpha']:.2f}, acc={best_combo['acc']:.6f}")

                df_curve = pd.DataFrame(perL_rows)
                curve_csv = os.path.join(
                    args.out_dir,
                    f"{model_name}_{dname}_combo_{combo_name(usages).replace('+','-')}_perlayer_alpha.csv"
                )
                df_curve.to_csv(curve_csv, index=False)

                all_rows.append({
                    "model": model_name,
                    "dataset": dname,
                    "usage_combo": combo_name(usages),
                    "flip_sign": args.flip_sign,
                    "multi_usage_mode": args.multi_usage_mode if len(usages)>1 else "n/a",
                    "use_stdonly": bool(args.use_stdonly),
                    "best_layer": f"L{best_combo['
