#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# === Configuration ===
# Root directory for all generated tables and figures; the processing loop
# creates one sub-directory per model underneath it.
OUT_DIR = "outputs/usage_neurons"

# Number of top-ranked neurons drawn in each per-layer figure.
TOPK = 20

# NOTE(review): not referenced anywhere in the visible code -- the per-layer
# width is taken from the loaded arrays instead. Kept in case other code
# relies on it; confirm before removing.
HIDDEN_DIM = 4096

# Usage categories exactly as they appear in the data file names.
USAGES = ["genre", "topic", "tone", "contextual"]

# The same categories as x-axis tick labels ("contextual" is shortened).
USAGES_PLOT = ["context" if name == "contextual" else name for name in USAGES]

# Model display name -> (directory holding the .npy files, number of layers).
# Models are processed in the order listed here.
DATA_DIRS = {
    "Mistral-7B": (
        "data/mistral-7b",
        32,
    ),
    "LLaMA-3-8B": (
        "data/llama-3-8b",
        32,
    ),
    "Gemma-7B": (
        "data/gemma-7b",
        28,
    ),
}

# Global figure styling, applied once before any figure is created:
# heavier black axes frame, inward-pointing ticks, frameless legend and a
# 12 pt base font.
plt.rc("axes", linewidth=1.2, edgecolor="black")
plt.rc("xtick", direction="in")
plt.rc("ytick", direction="in")
plt.rc("legend", frameon=False)
plt.rc("font", size=12)

def _load_usage_diff(data_dir, usage, layer):
    """Return the mean(positive) - mean(negative) vector for one usage/layer.

    A cached ``{usage}_layer{layer}_diff.npy`` is loaded when it exists.
    Otherwise the vector is computed from ``{usage}_layer{layer}.npy`` and
    ``{usage}_labels.npy`` and written to the cache path for later runs.

    Returns:
        The difference vector, or ``None`` when the input files are missing
        or when either label class (1 or 0) has no samples.
    """
    diff_path = os.path.join(data_dir, f"{usage}_layer{layer}_diff.npy")
    if os.path.exists(diff_path):
        return np.load(diff_path)

    emb_path = os.path.join(data_dir, f"{usage}_layer{layer}.npy")
    label_path = os.path.join(data_dir, f"{usage}_labels.npy")
    if not (os.path.exists(emb_path) and os.path.exists(label_path)):
        print(f"⛔ 缺失 {usage} 第{layer}层数据，跳过")
        return None

    emb = np.load(emb_path)       # presumably [N, H] (per the original annotation)
    labels = np.load(label_path)  # presumably [N]
    pos = emb[labels == 1]
    neg = emb[labels == 0]
    if pos.shape[0] == 0 or neg.shape[0] == 0:
        # The mean of an empty selection is NaN; caching it would silently
        # poison every later run, so treat this like missing data instead.
        print(f"⛔ {usage} layer {layer}: empty positive or negative class, skipping")
        return None

    diff = pos.mean(axis=0) - neg.mean(axis=0)  # [H]
    np.save(diff_path, diff)
    return diff


def _rank_neurons(neuron_matrix, layer):
    """Build the per-neuron statistics table for one layer, ranked by spread.

    ``neuron_matrix`` holds one difference vector per usage (rows) over the
    neurons (columns). The returned frame has one row per neuron, ordered
    by descending standard deviation across usages.

    NOTE(review): the original script used ``df`` without ever defining it;
    the columns and the sort key here are reconstructed from the statistics
    that were computed for it. Confirm against any downstream consumer of
    the CSV files.
    """
    std = neuron_matrix.std(axis=0)
    diff_range = neuron_matrix.max(axis=0) - neuron_matrix.min(axis=0)
    # True where the neuron's difference is positive for at least one usage
    # and negative for at least one other.
    polarity_flip = (neuron_matrix > 0).any(axis=0) & (neuron_matrix < 0).any(axis=0)
    hidden_dim = neuron_matrix.shape[1]

    table = pd.DataFrame({
        "layer": layer,  # lets rows be told apart in the concatenated CSV
        "neuron": np.arange(hidden_dim),
        "std": std,
        "range": diff_range,
        "polarity_flip": polarity_flip,
    })
    # A stable sort keeps ties in neuron-index order, so output is reproducible.
    ranked = table.sort_values("std", ascending=False, kind="stable")
    return ranked.reset_index(drop=True)


def _plot_top_neurons(model_name, layer, neuron_matrix, top_idx, fig_dir):
    """Save a line plot of the selected neurons' differences across usages."""
    top_values = neuron_matrix[:, top_idx]

    plt.figure(figsize=(12, 4))
    for i, nid in enumerate(top_idx):
        plt.plot(USAGES_PLOT, top_values[:, i], marker="o", linewidth=1.5, label=f"Neuron {nid}")
    plt.title(f"{model_name} - Layer {layer} (Top {TOPK})")
    plt.xlabel("Usage Type")
    plt.ylabel("Mean(Pos) - Mean(Neg)")
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=7)
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.tight_layout()
    plt.savefig(os.path.join(fig_dir, f"layer{layer}_top{TOPK}.png"), dpi=300)
    plt.close()


# For every configured model: rank the neurons of each layer by how much
# their positive-vs-negative activation difference varies across usages,
# write one CSV and one figure per layer, then one combined CSV per model.
for model_name, (data_dir, num_layers) in DATA_DIRS.items():
    print(f"\n🚀 Processing {model_name}")
    model_dir = os.path.join(OUT_DIR, model_name)
    fig_dir = os.path.join(model_dir, "figures")
    os.makedirs(fig_dir, exist_ok=True)  # also creates model_dir

    all_layers = []
    for layer in range(1, num_layers + 1):
        neuron_matrix = []
        for usage in USAGES:
            diff = _load_usage_diff(data_dir, usage, layer)
            if diff is None:
                break
            neuron_matrix.append(diff)

        # A layer is only analysed when every usage contributed a vector.
        if len(neuron_matrix) != len(USAGES):
            continue

        neuron_matrix = np.stack(neuron_matrix)  # [len(USAGES), H]

        df = _rank_neurons(neuron_matrix, layer)
        df.to_csv(os.path.join(model_dir, f"layer{layer}_usage_neurons.csv"), index=False)
        all_layers.append(df)

        top_idx = df.head(TOPK)["neuron"].values
        _plot_top_neurons(model_name, layer, neuron_matrix, top_idx, fig_dir)

        print(f"✅ {model_name} Layer {layer}")

    if all_layers:
        df_all = pd.concat(all_layers, ignore_index=True)
        df_all.to_csv(os.path.join(model_dir, f"{model_name}_all_usage_neurons.csv"), index=False)
        print(f"🎉 {model_name} ")
