#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

"meta-llama/Meta-Llama-3-8B-Instruct"
"google/gemma-7b"
"mistralai/Mistral-7B-Instruct-v0.2"

# Models to analyse: display name -> (Hugging Face model id, layer count, hidden size).
# The display name is also used as the sub-directory name for this model's outputs.
MODELS = {
    "Gemma-7B-IT": ("google/gemma-7b-it", 28, 3072),
    # "Mistral-7B": ("mistralai/Mistral-7B-Instruct-v0.2", 32, 4096),
    # "LLaMA-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", 32, 4096),
}
# Usage categories; each one has its own "<usage>_sentiment_dataset.json" file.
USAGES = ["genre", "topic", "tone", "contextual"]
DATA_DIR = "data/processed/train"      # directory holding the input JSON datasets
OUT_DIR = "outputs/usage_neurons_new"  # root directory for the per-model CSV reports
MAX_LEN = 128                          # tokenizer truncation length
BATCH_SIZE = 32                        # number of texts per forward pass
# Prefer the GPU, but fall back to the CPU so the script still runs without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

os.makedirs(OUT_DIR, exist_ok=True)



def load_labels(usage):
    """Read the sentiment dataset of one usage category.

    The file ``<DATA_DIR>/<usage>_sentiment_dataset.json`` is parsed as JSON and
    iterated as a sequence of records carrying a ``"label"`` and a ``"text"``
    entry. Records whose label is neither ``"positive"`` nor ``"negative"`` are
    dropped.

    Args:
        usage: Usage category name used to build the dataset file name.

    Returns:
        A ``(texts, labels)`` pair: the list of kept texts and a NumPy array
        with ``1`` for positive and ``0`` for negative, in the same order.
    """
    dataset_file = os.path.join(DATA_DIR, f"{usage}_sentiment_dataset.json")
    with open(dataset_file, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    # Encode the string labels as 0/1.
    label_map = {"positive": 1, "negative": 0}
    kept = [record for record in records if record["label"] in label_map]

    encoded = np.array([label_map[record["label"]] for record in kept])
    sentences = [record["text"] for record in kept]
    return sentences, encoded



def _masked_mean_pool(hidden, attention_mask):
    """Average token vectors over the sequence axis, ignoring padded positions."""
    mask = attention_mask.unsqueeze(-1).to(torch.float32)   # [batch, seq_len, 1]
    summed = (hidden.float() * mask).sum(dim=1)             # [batch, hidden_dim]
    counts = mask.sum(dim=1).clamp(min=1.0)                 # avoid 0/0 on an all-padding row
    return summed / counts


def save_embeddings(model_name, model_id, num_layers, hidden_dim):
    """Extract mean-pooled hidden states per layer and save them as ``.npy`` files.

    For every usage in ``USAGES`` the texts are run through the model in batches
    of ``BATCH_SIZE``. Each hidden-state tensor except the first one
    (``hidden_states[0]``, which per the Transformers documentation is the
    embedding output) is averaged over the real tokens of every sentence.
    Padding positions are excluded through the attention mask, so an embedding
    does not depend on how long the other sentences in its batch are.

    Files written to ``data/<model_name>/``:
        ``<usage>_labels.npy``     -- 0/1 labels returned by ``load_labels``
        ``<usage>_layer<k>.npy``   -- array of shape [N, hidden], k starting at 1

    Args:
        model_name: Display name; used for logging and as output sub-directory.
        model_id: Identifier passed to ``from_pretrained``.
        num_layers: Expected number of layers; only used for a sanity warning.
        hidden_dim: Expected hidden size; only used for a sanity warning.
    """
    print(f"\n🚀 {model_name}  embeddings")
    out_dir = os.path.join("data", model_name)
    os.makedirs(out_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Batched padding needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModel.from_pretrained(model_id, output_hidden_states=True).to(DEVICE)
    model.eval()

    for usage in USAGES:
        texts, labels = load_labels(usage)
        np.save(os.path.join(out_dir, f"{usage}_labels.npy"), labels)

        if not texts:
            # Nothing to embed; without this guard the save loop below would crash.
            print(f"⚠️ {usage}: no labelled texts found, skipping embeddings")
            continue

        # layer_chunks[k] collects the pooled batches of layer k+1; they are
        # concatenated once at the end instead of re-stacking after every batch.
        layer_chunks = None
        for start in tqdm(range(0, len(texts), BATCH_SIZE), desc=f"{model_name}-{usage}"):
            batch_texts = texts[start:start + BATCH_SIZE]
            inputs = tokenizer(batch_texts, return_tensors="pt", padding=True,
                               truncation=True, max_length=MAX_LEN).to(DEVICE)
            with torch.no_grad():
                outputs = model(**inputs)
                # Sequence of [batch, seq_len, hidden_dim] tensors; skip index 0.
                pooled = [
                    _masked_mean_pool(hs, inputs["attention_mask"]).cpu().numpy()
                    for hs in outputs.hidden_states[1:]
                ]

            if layer_chunks is None:
                layer_chunks = [[] for _ in pooled]
            for bucket, layer_batch in zip(layer_chunks, pooled):
                bucket.append(layer_batch)

        # Warn when the configured model shape disagrees with what the model produced.
        layers_found = len(layer_chunks)
        width_found = layer_chunks[0][0].shape[1]
        if layers_found != num_layers or width_found != hidden_dim:
            print(f"⚠️ {model_name}: expected {num_layers} layers x {hidden_dim} dims, "
                  f"got {layers_found} x {width_found}")

        for layer_id, chunks in enumerate(layer_chunks, start=1):
            arr = np.concatenate(chunks, axis=0)
            out_path = os.path.join(out_dir, f"{usage}_layer{layer_id}.npy")
            np.save(out_path, arr)
            print(f"✅ Saved {out_path}, shape={arr.shape}")



def compute_diffs_and_csv(model_name, num_layers, hidden_dim):
    """Compute per-neuron positive/negative activation gaps and write CSV reports.

    For each layer ``1..num_layers`` and each usage in ``USAGES`` the saved
    embeddings ``data/<model_name>/<usage>_layer<k>.npy`` and labels
    ``data/<model_name>/<usage>_labels.npy`` are loaded, and the difference
    between the mean embedding of positive (label 1) and negative (label 0)
    samples is taken. A layer is skipped entirely, with a message, when any
    usage is missing files, has mismatching sample counts, or lacks one of the
    two classes.

    For every remaining layer a CSV ``layer<k>_usage_neurons.csv`` is written to
    ``<OUT_DIR>/<model_name>/`` with one row per neuron, sorted by ``range``
    descending. Its columns are ``layer``, ``neuron``, ``std``, ``range``,
    ``polarity_flip`` and one ``diff_<usage>`` column per usage. All layer
    tables are also concatenated into ``<model_name>_all_usage_neurons.csv``.

    Args:
        model_name: Display name; selects the input and output sub-directories.
        num_layers: Number of layers to look for (1-based, inclusive).
        hidden_dim: Expected number of neurons per layer. The neuron index is
            taken from the stored embeddings; a mismatch only triggers a warning.
    """
    print(f"\n📊 计算 {model_name} 的 diff & CSV")
    data_dir = os.path.join("data", model_name)
    model_out_dir = os.path.join(OUT_DIR, model_name)
    os.makedirs(model_out_dir, exist_ok=True)

    all_layers = []
    for layer in range(1, num_layers + 1):
        neuron_matrix = []
        ok = True
        for usage in USAGES:
            emb_path = os.path.join(data_dir, f"{usage}_layer{layer}.npy")
            label_path = os.path.join(data_dir, f"{usage}_labels.npy")
            if not (os.path.exists(emb_path) and os.path.exists(label_path)):
                # Report the skip like the other failure paths do.
                print(f"⚠️ missing embeddings or labels for {usage} layer{layer}, skipping layer")
                ok = False
                break

            emb = np.load(emb_path)      # [N,H]
            labels = np.load(label_path) # [N]

            if emb.shape[0] != len(labels):
                print(f"❌ mismatch {usage} layer{layer}: emb={emb.shape[0]}, labels={len(labels)}")
                ok = False
                break

            # Both classes are required, otherwise one of the means is undefined.
            if (labels == 1).sum() == 0 or (labels == 0).sum() == 0:
                print(f"⚠️ {usage} layer{layer} 没有正样本或负样本，跳过")
                ok = False
                break

            pos_mean = emb[labels == 1].mean(axis=0)
            neg_mean = emb[labels == 0].mean(axis=0)
            diff = (pos_mean - neg_mean).ravel()
            neuron_matrix.append(diff)

        if not ok:
            continue

        neuron_matrix = np.stack(neuron_matrix)  # [len(USAGES), H]

        # Index neurons by the width actually stored on disk: a wrong hidden_dim
        # would otherwise make the DataFrame constructor raise on length mismatch.
        n_neurons = neuron_matrix.shape[1]
        if n_neurons != hidden_dim:
            print(f"⚠️ layer{layer}: expected hidden_dim={hidden_dim}, found {n_neurons}")

        # How much each neuron's gap varies across usages.
        std = neuron_matrix.std(axis=0)
        diff_range = neuron_matrix.max(axis=0) - neuron_matrix.min(axis=0)
        # True when the gap is positive for one usage and negative for another.
        polarity_flip = ((neuron_matrix > 0).any(axis=0)) & ((neuron_matrix < 0).any(axis=0))

        df = pd.DataFrame({
            "layer": layer,
            "neuron": np.arange(n_neurons),
            "std": std,
            "range": diff_range,
            "polarity_flip": polarity_flip,
            **{f"diff_{u}": neuron_matrix[i] for i, u in enumerate(USAGES)}
        }).sort_values("range", ascending=False)

        csv_path = os.path.join(model_out_dir, f"layer{layer}_usage_neurons.csv")
        df.to_csv(csv_path, index=False)
        all_layers.append(df)

    if all_layers:
        df_all = pd.concat(all_layers, ignore_index=True)
        df_all.to_csv(os.path.join(model_out_dir, f"{model_name}_all_usage_neurons.csv"), index=False)
        print(f"🎉 {model_name} all layers CSV saved")



# Run both stages for every configured model, but only when executed as a
# script: importing this module must not load models or touch the GPU.
if __name__ == "__main__":
    for model_name, (model_id, num_layers, hidden_dim) in MODELS.items():
        # Stage 1: dump per-layer sentence embeddings and labels to disk.
        save_embeddings(model_name, model_id, num_layers, hidden_dim)
        # Stage 2: turn the saved arrays into per-neuron difference reports.
        compute_diffs_and_csv(model_name, num_layers, hidden_dim)
