#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import pandas as pd
import matplotlib.pyplot as plt

# === Configuration ===
# Root folder under which every model's usage-neuron results live.
_RESULTS_ROOT = "outputs/usage_neurons_new"

# Models to plot; each maps to "<root>/<model name>".
# Other models with results (add their names below to include them):
#   "Gemma-7B", "Gemma-7B-IT"
DATA_DIRS = {
    name: f"{_RESULTS_ROOT}/{name}"
    for name in ("Mistral-7B", "LLaMA-3-8B")
}

# (CSV column suffix, x-axis tick label) per usage type. Deriving both lists
# from one table keeps them index-aligned.
_USAGE_SPECS = (
    ("genre", "genre"),
    ("topic", "topic"),
    ("tone", "tone"),
    ("contextual", "context"),
)
USAGES = [column for column, _label in _USAGE_SPECS]
USAGES_PLOT = [label for _column, label in _USAGE_SPECS]

TOPK = 10
LAYER = 15
MODE = "flip_first"   # one of: "flip_first" / "topk_first"

OUT_DIR = f"{_RESULTS_ROOT}/figures_layer15_final_modes"
os.makedirs(OUT_DIR, exist_ok=True)

# Shared figure style for every plot in this script: thick black axes frame,
# inward-pointing ticks, frameless legends and a large base font.
_PLOT_STYLE = {
    "font.size": 20,
    "axes.edgecolor": "black",
    "axes.linewidth": 1.5,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "legend.frameon": False,
}
plt.rcParams.update(_PLOT_STYLE)


def plot_flipped_topk(df, model_name, layer, topn=10, sort_by="range"):
    """Plot the top ``topn`` polarity-flipped neurons ranked by ``sort_by``.

    Keep only rows whose ``polarity_flip`` equals True, then take the
    ``topn`` with the largest ``sort_by`` value. Draw one line per neuron
    over the usage types in ``USAGES`` (x tick labels from ``USAGES_PLOT``).

    Args:
        df: Per-neuron table with a ``neuron`` column, a ``polarity_flip``
            column, the ``sort_by`` column and one ``diff_<usage>`` column
            per entry in ``USAGES``.
        model_name: Used in the output file name and the log messages.
        layer: Layer index, shown in the title and the output file name.
        topn: Maximum number of flipped neurons to draw.
        sort_by: Column used to rank the flipped neurons (descending).

    Returns:
        Path of the saved PNG inside ``OUT_DIR``, or ``None`` when ``df``
        contains no flipped neurons.
    """
    # ``.eq(True)`` (equivalent to ``== True``) rather than truthiness, so
    # missing values (NaN) count as "not flipped".
    df_flips = df[df["polarity_flip"].eq(True)]
    if df_flips.empty:
        print(f"⚠️ {model_name} Layer {layer} 没有翻转 neuron")
        return None

    df_top = df_flips.sort_values(sort_by, ascending=False).head(topn)
    # Fewer than ``topn`` rows remain when the layer has few flipped neurons.
    # The title reports the number actually drawn so it is never overstated.
    n_shown = len(df_top)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        for _, row in df_top.iterrows():
            diffs = [row[f"diff_{u}"] for u in USAGES]
            # The labels are kept so a legend can be re-enabled later. No
            # legend is drawn on purpose.
            ax.plot(USAGES_PLOT, diffs, marker="o", linewidth=2,
                    label=f"Neuron {int(row['neuron'])}")

        ax.axhline(0, color="black", linewidth=1, linestyle="--")
        ax.grid(True, linestyle=(0, (6, 10)), linewidth=1.0, alpha=0.6)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # The model name is deliberately left out of the title (it is in the
        # file name instead).
        ax.set_title(f"Layer {layer} (Top {n_shown} flipped neurons)", pad=20)
        ax.set_xlabel("Usage Type")
        ax.set_ylabel("Mean(Pos) - Mean(Neg)")
        fig.tight_layout()

        out_path = os.path.join(
            OUT_DIR, f"{model_name}_layer{layer}_flip_top{topn}.png"
        )
        fig.savefig(out_path, dpi=400, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    print(f"✅ Saved {out_path}")
    return out_path


def plot_topk_then_flip(df, model_name, layer, topn=10, sort_by="range"):
    """Plot the top ``topn`` neurons by ``sort_by``, highlighting flips in red.

    Take the ``topn`` rows with the largest ``sort_by`` value, regardless of
    polarity. Draw one line per neuron over the usage types in ``USAGES``
    (x tick labels from ``USAGES_PLOT``). Lines are red when
    ``polarity_flip`` equals True and gray otherwise.

    Args:
        df: Per-neuron table with a ``neuron`` column, a ``polarity_flip``
            column, the ``sort_by`` column and one ``diff_<usage>`` column
            per entry in ``USAGES``.
        model_name: Shown in the title and used in the output file name.
        layer: Layer index, shown in the title and the output file name.
        topn: Maximum number of neurons to draw.
        sort_by: Column used to rank neurons (descending).

    Returns:
        Path of the saved PNG inside ``OUT_DIR``.
    """
    df_top = df.sort_values(sort_by, ascending=False).head(topn)
    # May be fewer than ``topn`` rows, so the title reports the count drawn.
    n_shown = len(df_top)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        for _, row in df_top.iterrows():
            # Compare with True instead of using truthiness. NaN is truthy and
            # would otherwise be drawn as a flipped neuron. This also matches
            # the ``polarity_flip`` filter used by plot_flipped_topk.
            flipped = bool(row["polarity_flip"] == True)  # noqa: E712
            diffs = [row[f"diff_{u}"] for u in USAGES]
            ax.plot(USAGES_PLOT, diffs, marker="o", linewidth=2,
                    label=f"Neuron {int(row['neuron'])}",
                    color="red" if flipped else "gray",
                    # Keep highlighted (red) lines above the gray ones.
                    zorder=3 if flipped else 2)

        ax.axhline(0, color="black", linewidth=1, linestyle="--")
        ax.grid(True, linestyle=(0, (6, 10)), linewidth=1.0, alpha=0.6)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        ax.set_title(f"{model_name} - Layer {layer} (Top {n_shown}, flipped=red)")
        ax.set_xlabel("Usage Type")
        ax.set_ylabel("Mean(Pos) - Mean(Neg)")
        fig.tight_layout()

        out_path = os.path.join(
            OUT_DIR, f"{model_name}_layer{layer}_top{topn}_with_flips.png"
        )
        fig.savefig(out_path, dpi=400, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    print(f"✅ Saved {out_path}")
    return out_path


# === Main flow ===
# Map each supported MODE to the plotting routine it selects.
_PLOTTERS = {
    "flip_first": plot_flipped_topk,
    "topk_first": plot_topk_then_flip,
}
# Fail fast on a mistyped MODE. Otherwise the loop below would read every CSV
# and silently produce no figures.
if MODE not in _PLOTTERS:
    raise ValueError(
        f"Unknown MODE {MODE!r}; expected one of {sorted(_PLOTTERS)}"
    )
_plot = _PLOTTERS[MODE]

for model_name, data_dir in DATA_DIRS.items():
    csv_path = os.path.join(data_dir, f"layer{LAYER}_usage_neurons.csv")
    # isfile (not exists) so a directory with this name is skipped too.
    if not os.path.isfile(csv_path):
        print(f"⚠️ {model_name} 没有第 {LAYER} 层的文件")
        continue

    df = pd.read_csv(csv_path)
    _plot(df, model_name, layer=LAYER, topn=TOPK, sort_by="range")
