from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Plot accuracy on a (coefficient x layer) grid as an annotated heatmap,
# save it as a PNG, and display it.

# Axis values of the sweep. The accuracy list below is ordered layer-major:
# for each layer in LAYERS, one value per coefficient in COEFFICIENTS.
LAYERS = [8, 12, 16, 20, 24]
COEFFICIENTS = [0.5, 1.0, 2.0, 3.0, 5.0]

# Where the rendered figure is written (relative to the working directory).
OUTPUT_PATH = Path("analysis/output/heatmap_mmlu.png")

data = {
    # Each layer repeated once per coefficient: 8, 8, 8, 8, 8, 12, ...
    'layer': [layer for layer in LAYERS for _ in COEFFICIENTS],
    # The full coefficient sweep repeated once per layer.
    'coefficient': COEFFICIENTS * len(LAYERS),
    # NOTE(review): alternative accuracy values left disabled in the original
    # script; their provenance is not visible here -- confirm whether they are
    # still needed before deleting.
    # 'accuracy': [0.7289, 0.7445, 0.7486, 0.7093, 0.5266, 0.7224, 0.7166, 0.7346, 0.6855, 0.2924, 0.7158, 0.7289, 0.7215, 0.7199, 0.6880, 0.7133, 0.7338, 0.7240, 0.7355, 0.7215, 0.7142, 0.7289, 0.7273, 0.7256, 0.7060]
    'accuracy': [
        0.6997, 0.7475, 0.7153, 0.6768, 0.4362,  # layer 8
        0.7034, 0.7548, 0.7227, 0.6905, 0.4040,  # layer 12
        0.6896, 0.7585, 0.7181, 0.7163, 0.6657,  # layer 16
        0.6823, 0.7319, 0.7456, 0.7264, 0.7071,  # layer 20
        0.6823, 0.7475, 0.7438, 0.7410, 0.7383,  # layer 24
    ],
}
df = pd.DataFrame(data)

# Reshape the long-format table into a grid: rows are coefficients, columns
# are layers, cells are accuracy.
heatmap_data = df.pivot(index='coefficient', columns='layer', values='accuracy')

# Label the colorbar through cbar_kws rather than reaching into
# ax.collections[0], which depends on the order artists were added.
ax = sns.heatmap(
    heatmap_data,
    annot=True,
    fmt=".4f",
    cmap="YlOrBr",
    cbar=True,
    cbar_kws={"label": "Accuracy"},
)

# Set labels on the heatmap's own axes instead of pyplot's implicit
# "current axes".
ax.set_xlabel("Layer")
ax.set_ylabel("Coefficient")
ax.set_title("MMLU-Med: Layer vs Coefficient")

# savefig does not create missing directories; make sure the target exists so
# a fresh checkout does not fail with FileNotFoundError.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(OUTPUT_PATH, dpi=300, bbox_inches='tight')
plt.show()
