#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os, json, argparse, random, gc, re
import numpy as np, pandas as pd
import torch
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModel

# ---------------- Basic configuration ----------------
# Run on the GPU whenever torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 128      # tokenizer truncation / padding length
BATCH_SIZE = 128   # number of texts per forward pass
SEED = 42

# Seed the Python, NumPy and torch RNGs for reproducible runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# === Base directory for your new axes (named as model_lower + mode) ===
AXIS_BASE_DIR = "experiments/axis_from_new_neurons"

# Models to evaluate: display name -> Hugging Face model id and the number of
# layers to probe. Disabled alternatives (add back to enable):
#   "Gemma-7B":   google/gemma-7b                      (28 layers)
#   "LLaMA-3-8B": meta-llama/Meta-Llama-3-8B-Instruct  (32 layers)
#   "Mistral-7B": mistralai/Mistral-7B-Instruct-v0.2   (32 layers)
MODELS = {
    "Gemma-7B-IT": dict(model_id="google/gemma-7b-it", layers=28),
}

# Evaluation datasets, in the order they are processed.
DATASETS = [
    "twitter",
    "Dialogue",
    "AnimalsBeingBros",
    "Confession",
    "Cringe",
    "OkCupid",
    "imdb",
    "sst5",
]

# Local test files per dataset; a None entry means the split is fetched with
# `datasets.load_dataset` inside load_dataset_split instead.
TEST_FILES = {
    "Dialogue": "data/processed/test/dailydialog_emotion_filtered.json",
    "AnimalsBeingBros": "data/processed/test/AnimalsBeingBros_posneg.json",
    "Confession": "data/processed/test/confession_posneg.json",
    "Cringe": "data/processed/test/cringe_posneg.json",
    "OkCupid": "data/processed/test/OkCupid_posneg.json",
    "sst5": "data/processed/test/sst5_binary_phrases.txt",
    "twitter": None,
    "imdb": None,
}

# ---------------- Dataset loading ----------------
# sst5 line format: "<text><sep><positive|negative>", where <sep> is a run of
# commas, tabs and/or spaces. Compiled once at import time.
_SST5_LINE_RE = re.compile(
    r"^(?P<text>.+?)\s*[,\t ]+\s*(?P<label>positive|negative)\s*$", re.IGNORECASE
)
# Keys probed, in this order, for the label of a JSON record.
_JSON_LABEL_KEYS = ("label", "valence", "sentiment", "emotion")


def _load_twitter():
    """Return the tweet_eval/sentiment test split, keeping labels 0 and 2 (2 -> 1)."""
    ds = load_dataset("tweet_eval", "sentiment", split="test")
    texts, labels = [], []
    for text, label in zip(ds["text"], ds["label"]):
        if label in (0, 2):  # label 1 (neutral per the tweet_eval card) is dropped
            texts.append(text)
            labels.append(1 if label == 2 else 0)
    return texts, np.array(labels, dtype=int)


def _load_imdb():
    """Return the IMDB test split with its labels used as provided."""
    ds = load_dataset("imdb", split="test")
    # list(...) so downstream slicing always yields a plain list of str.
    return list(ds["text"]), np.array(ds["label"], dtype=int)


def _load_sst5_lines(path):
    """Parse '<text> <positive|negative>' lines from *path*; unmatched lines are skipped."""
    texts, labels = [], []
    bad = 0
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            m = _SST5_LINE_RE.match(s)
            if not m:
                bad += 1
                continue
            text = m.group("text").strip()
            if text:
                texts.append(text)
                labels.append(1 if m.group("label").lower() == "positive" else 0)
    if bad:
        print(f"[sst5] skipped {bad} lines that did not match '... <label>' pattern")
    return texts, np.array(labels, dtype=int)


def _first_label(record):
    """Return the first label value in *record* that is neither None nor "".

    Unlike an ``a or b or ...`` chain, this keeps falsy-but-valid labels such
    as the integer 0 (negative), which the chain would otherwise skip.
    """
    for key in _JSON_LABEL_KEYS:
        value = record.get(key)
        if value is not None and value != "":
            return value
    return None


def _label_to_int(raw):
    """Map a raw JSON label to an int, or return None if the record should be skipped.

    Strings must be 'positive'/'negative' (case/whitespace-insensitive) and map
    to 1/0; integers are passed through unchanged; anything else is skipped.
    """
    if isinstance(raw, str):
        norm = raw.strip().lower()
        if norm not in {"positive", "negative"}:
            return None
        return 1 if norm == "positive" else 0
    if isinstance(raw, int):
        return raw
    return None


def _load_json_records(path):
    """Load a JSON list of records, each with a label key and a 'text'/'utterance' key."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    texts, labels = [], []
    for record in data:
        label = _label_to_int(_first_label(record))
        if label is None:
            continue
        text = record.get("text") or record.get("utterance")
        if text:
            texts.append(text)
            labels.append(label)
    return texts, np.array(labels, dtype=int)


def load_dataset_split(name):
    """Load the test split of dataset *name* as ``(texts, labels)``.

    ``texts`` is a list of strings and ``labels`` a NumPy int array of the
    same length. String labels 'positive'/'negative' become 1/0. "twitter"
    keeps only tweet_eval labels 0 and 2 (mapped to 0 and 1). "imdb" and
    integer JSON labels are used unchanged.

    "twitter" and "imdb" are fetched with `datasets.load_dataset`; "sst5" is
    a text file and every other name a JSON file, both located via TEST_FILES.

    Raises:
        KeyError: if *name* is not a key of TEST_FILES (for local datasets).
    """
    if name == "twitter":
        return _load_twitter()
    if name == "imdb":
        return _load_imdb()
    if name == "sst5":
        return _load_sst5_lines(TEST_FILES[name])
    return _load_json_records(TEST_FILES[name])

@torch.inference_mode()
def encode_all_layers(texts, tokenizer, model, num_layers):
    """Mean-pool the last *num_layers* hidden-state tensors for every text.

    Texts are tokenised in batches of BATCH_SIZE, padded/truncated to MAX_LEN
    and run through *model* with ``output_hidden_states=True``. Each of the
    final *num_layers* entries of ``hidden_states`` is averaged over positions
    where ``attention_mask`` is 1.

    With Hugging Face models ``hidden_states`` usually has one more entry than
    the model has layers (the first is the embedding output), so passing the
    model depth as *num_layers* skips the embeddings. Confirm this per model.

    Args:
        texts: sequence of strings (sliceable).
        tokenizer: tokenizer callable returning PyTorch tensors.
        model: model whose output exposes ``hidden_states``.
        num_layers: how many trailing hidden-state tensors to pool.

    Returns:
        A list of ``num_layers`` float32 arrays of shape
        ``(len(texts), hidden_size)``. Entry ``L-1`` holds layer ``L``, where
        layer ``num_layers`` is the last hidden state.

    Raises:
        ValueError: if the model returns fewer than *num_layers* hidden states
            (previously this silently indexed from the end and pooled the
            wrong layers).
    """
    hidden_size = model.config.hidden_size
    n = len(texts)
    all_layers = [np.empty((n, hidden_size), dtype=np.float32) for _ in range(num_layers)]
    for start in tqdm(range(0, n, BATCH_SIZE), desc="🧠 Encoding"):
        stop = start + BATCH_SIZE
        batch = list(texts[start:stop])  # tokenizers expect a plain list of str
        inputs = tokenizer(batch, return_tensors="pt", padding="max_length",
                           truncation=True, max_length=MAX_LEN).to(model.device)
        hs = model(**inputs, output_hidden_states=True).hidden_states
        offset = len(hs) - num_layers  # align with different model depths
        if offset < 0:
            raise ValueError(
                f"model returned {len(hs)} hidden states but num_layers={num_layers}"
            )
        # Pool in float32: half/bfloat16 sums lose precision, and bfloat16
        # tensors cannot be converted with .numpy(). No-op for float32 models.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        denom = mask.sum(1).clamp_min(1)
        for layer_idx in range(num_layers):
            h = hs[layer_idx + offset].float()
            pooled = (h * mask).sum(1) / denom
            all_layers[layer_idx][start:stop] = pooled.cpu().numpy()
    return all_layers

# ---------- Load axis from "axis_from_new_neurons" ----------
def load_axis_new(axis_base_dir, model_key, mode_tag, L):
    """Load one saved sentiment axis and scale it to unit length.

    The file is expected at
    ``axis_base_dir/{model_key.lower()}_{mode_tag}/sentiment_axis_L{L}.npy``
    (``L`` without zero padding; mode_tag ∈ {flip, stable, flip_stable}).

    A 1-D array is used as-is; otherwise its first entry along axis 0 is taken.

    Returns:
        The unit-norm axis, or ``None`` if the file does not exist or the
        vector's norm is not positive.
    """
    axis_file = os.path.join(
        axis_base_dir,
        f"{model_key.lower()}_{mode_tag}",
        f"sentiment_axis_L{L}.npy",  # no leading zeros for L
    )
    if not os.path.exists(axis_file):
        return None

    stored = np.load(axis_file)
    axis = stored[0] if stored.ndim != 1 else stored

    length = np.linalg.norm(axis)
    if length > 0:
        return axis / length
    return None

# ---------------- Main ----------------
def _load_axes(axis_dir, model_name, mode_tags, num_layers):
    """Load every available axis for *model_name*, as ``{mode_tag: {layer: unit_vector}}``."""
    axes = {}
    for tag in mode_tags:
        per_layer = {}
        for L in range(1, num_layers + 1):
            v = load_axis_new(axis_dir, model_name, tag, L)
            if v is not None:  # axis not available for this layer -> skip it
                per_layer[L] = v
        if not per_layer:
            print(f"⚠️ No axes found for {model_name} / {tag} under {axis_dir}")
        axes[tag] = per_layer
    return axes


def main():
    """Score saved sentiment axes as zero-threshold linear probes.

    For every model in MODELS and dataset in DATASETS, texts are encoded with
    encode_all_layers. The pooled representation of layer L is projected onto
    that layer's axis (from load_axis_new), and a text is predicted positive
    when the projection is > 0.

    Two CSVs are written to --out_dir:
      all_axis_accuracy.csv  accuracy per (model, dataset, mode, layer);
      best_axis.csv          best layer per (model, dataset, mode). best_acc
                             is NaN and best_layer empty when no axis exists.

    CLI flags:
      --modes     comma-separated axis modes (default: flip,stable,flip_stable);
                  "+" is mapped to "_" to match the saved folder names.
      --axis_dir  root directory containing the saved axes.
      --out_dir   output directory (created if missing).
    """
    ap = argparse.ArgumentParser()
    # Modes: flip, stable, flip_stable; default is all three
    ap.add_argument("--modes", default="flip,stable,flip_stable")
    ap.add_argument("--axis_dir", default=AXIS_BASE_DIR,
                    help="Path to experiments/axis_from_new_neurons")
    ap.add_argument("--out_dir", default="output/results_new_axes/g_it_stable")
    args = ap.parse_args()

    # The directory tag must match the saved folders: flip / stable / flip_stable
    mode_tags = [m.strip().replace("+", "_") for m in args.modes.split(",") if m.strip()]
    os.makedirs(args.out_dir, exist_ok=True)

    all_scores, best_rows = [], []

    for model_name, cfg in MODELS.items():
        num_layers = cfg["layers"]
        print(f"\n🚀 Evaluating {model_name}")
        tok = AutoTokenizer.from_pretrained(cfg["model_id"])
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        model = AutoModel.from_pretrained(cfg["model_id"]).to(DEVICE).eval()

        # Axes depend only on (model, mode, layer), not on the dataset: load once.
        axes = _load_axes(args.axis_dir, model_name, mode_tags, num_layers)

        for dname in DATASETS:
            texts, labels = load_dataset_split(dname)
            if len(texts) == 0:
                print(f"⚠️ {model_name} - {dname}: no samples, skipped")
                continue
            print(f"\n📊 {model_name} - {dname}: {len(texts)} samples")

            layers = encode_all_layers(texts, tok, model, num_layers)

            for mode_tag in mode_tags:
                # NaN/None until at least one axis is scored, so "no axis" is not
                # reported as 0.0 accuracy, and an all-zero run still records a layer.
                best_acc, best_layer = float("nan"), None

                for L, v in axes[mode_tag].items():
                    scores = (layers[L - 1] @ v).ravel()
                    pred = (scores > 0).astype(int)
                    acc = accuracy_score(labels, pred)

                    all_scores.append({
                        "model": model_name, "dataset": dname,
                        "mode": mode_tag, "layer": f"L{L}", "acc": acc
                    })
                    if best_layer is None or acc > best_acc:
                        best_acc, best_layer = acc, f"L{L}"

                best_rows.append({
                    "model": model_name, "dataset": dname,
                    "mode": mode_tag, "best_acc": best_acc, "best_layer": best_layer
                })

            # Free this dataset's per-layer embeddings (num_layers arrays of
            # N x hidden_size floats) before encoding the next dataset.
            del layers

        del model, tok
        gc.collect()  # collect first so empty_cache can release the freed blocks
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"🧹 Finished {model_name}, GPU memory cleared")

    all_df = pd.DataFrame(all_scores)
    best_df = pd.DataFrame(best_rows)
    all_df.to_csv(os.path.join(args.out_dir, "all_axis_accuracy.csv"), index=False)
    best_df.to_csv(os.path.join(args.out_dir, "best_axis.csv"), index=False)

    print("\n📋 Best layers:")
    if best_df.empty:
        # An empty DataFrame has no columns, so sort_values would raise KeyError.
        print("(no results)")
    else:
        print(best_df.sort_values(
            by=["model", "dataset", "mode", "best_acc"], ascending=[True, True, True, False]
        ))


if __name__ == "__main__":
    main()
