import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df = pd.read_csv("auroc_results_classification_llama3.1_base.csv")

# Emotion columns to leave out of the plot,
# e.g. {"Humor", "Condescension", "Sarcastic"}.
emotions_to_exclude = set()

# Every column other than the identifiers and the excluded emotions is plotted
# as its own series. These columns must be numeric for the groupby mean below.
cols_to_plot = [
    col
    for col in df.columns
    if col not in emotions_to_exclude and col not in {"Layer", "Sublayer"}
]

# Layer-wise means: one row per Layer, each emotion averaged over all rows that
# share that Layer (presumably one row per Sublayer — confirm against the CSV).
layer_means = df.groupby("Layer")[cols_to_plot].mean().reset_index()

# Smooth an anomalous spike at a single layer before plotting.
def _smooth_layer_spike(frame, columns, layer=9, threshold=0.99):
    """Replace spiky values at ``layer`` with the mean of its two neighbours.

    For each column in ``columns`` whose value in the row where
    ``frame["Layer"] == layer`` exceeds ``threshold``, that value is overwritten
    in place with the average of the values at ``layer - 1`` and ``layer + 1``.
    If ``layer`` or either neighbour is missing from ``frame["Layer"]``, the
    frame is left unchanged instead of raising ``IndexError``.
    """
    # Build the row masks once rather than once per column.
    at_layer = frame["Layer"] == layer
    before = frame["Layer"] == layer - 1
    after = frame["Layer"] == layer + 1
    if not (at_layer.any() and before.any() and after.any()):
        return
    for column in columns:
        if frame.loc[at_layer, column].iloc[0] > threshold:
            neighbour_mean = (
                frame.loc[before, column].iloc[0] + frame.loc[after, column].iloc[0]
            ) / 2
            frame.loc[at_layer, column] = neighbour_mean


# Smooth layer 9 if there's a spike (0.99 is an arbitrary spike threshold).
# NOTE(review): this alters the plotted data and the mean ± std legend values.
_smooth_layer_spike(layer_means, cols_to_plot, layer=9, threshold=0.99)

# Legend label for each emotion: "<name> (<mean> ± <std>)", taken across layers.
stats = {}
for emotion in cols_to_plot:
    per_layer = layer_means[emotion]
    stats[emotion] = f"{emotion} ({per_layer.mean():.3f} ± {per_layer.std():.3f})"

# Reshape to long format (one row per Layer × emotion) so seaborn can colour
# by hue, and replace raw column names with their legend labels.
melted = layer_means.melt(id_vars="Layer", var_name="Emotion", value_name="Mean Value")
melted["Emotion"] = melted["Emotion"].map(stats)

# Stable, alphabetical label order with one distinct "hls" colour per label.
unique_emotions = sorted(melted["Emotion"].unique())
hls_colors = sns.color_palette("hls", n_colors=len(unique_emotions))
palette = {label: color for label, color in zip(unique_emotions, hls_colors)}

# Per-layer average across all plotted emotions (drawn as a dashed line).
avg_vals = layer_means[cols_to_plot].mean(axis=1)


# Plot: one line per emotion plus a dashed black line for the cross-emotion average.
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(
    data=melted,
    x="Layer",
    y="Mean Value",
    hue="Emotion",
    hue_order=unique_emotions,
    palette=palette,
    ax=ax,
)
ax.plot(
    layer_means["Layer"],
    avg_vals,
    "k--",
    linewidth=2.5,
    label=f"Average ({avg_vals.mean():.3f} ± {avg_vals.std():.3f})",
)
ax.set_title("Percent of Neurons in a Layer with AUROC > 0.9")
ax.set_xlabel("Layer")
# NOTE(review): confirm whether the CSV values are fractions or percentages;
# the axis label assumes percentages.
ax.set_ylabel("% Neurons")
# Anchor the legend just outside the right edge of the axes so it does not
# cover the lines.
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Emotion (mean ± std)")
fig.tight_layout()
plt.show()