import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# Model checkpoints whose results are plotted. Each name is also the
# sub-directory under Exps/exp_mactivation_attnsink/ that holds its results.
model_name_1 = "Llama-2-7b-hf"
model_name_2 = "Meta-Llama-3-8B"
model_name_3 = "Mistral-7B-v0.1"

# Experiment variant: selects which "<mode>.npy" file is loaded for each model.
mode = "set_swap"

save_path_1 = f"Exps/exp_mactivation_attnsink/{model_name_1}/{mode}.npy"
save_path_2 = f"Exps/exp_mactivation_attnsink/{model_name_2}/{mode}.npy"
save_path_3 = f"Exps/exp_mactivation_attnsink/{model_name_3}/{mode}.npy"

# np.load raises FileNotFoundError if an experiment has not been run yet.
data_1 = np.load(save_path_1)  # results for model_name_1 (Llama-2-7b-hf)
data_2 = np.load(save_path_2)  # results for model_name_2 (Meta-Llama-3-8B)
data_3 = np.load(save_path_3)  # results for model_name_3 (Mistral-7B-v0.1)

# Curves to plot, keyed by legend label. Only the first column of each saved
# array is used ([:, 0] requires at least 2-D arrays; presumably one row per
# x-axis point k — TODO confirm against the experiment script).
# Each label must point at the array loaded from the matching model directory
# above; keep this mapping in sync if model_name_* change.
data = {
    "1": {
        "Mistral-7B": data_3[:, 0],
        "Llama2-7B": data_1[:, 0],
        "Llama3-8B": data_2[:, 0],
    }
}

# Select the curve group to draw (the only group defined in `data` is "1").
model = "1"
curves = data[model]

# x positions k = 1..K and their tick labels. K is derived from the data so
# the plot keeps working if the experiment is re-run with a different K.
n_points = max((len(y) for y in curves.values()), default=0)
T = np.arange(1, n_points + 1)
T_prime = T.copy()

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# Distinct marker per legend label; unknown labels fall back to "o".
markers = {
    "Mistral-7B": "o",
    "Llama2-7B": "*",
    "Llama3-8B": "+",
}

for label, y in curves.items():
    ax.plot(
        np.arange(1, len(y) + 1),  # k = 1..len(y) for this curve
        y,
        marker=markers.get(label, "o"),
        linewidth=2,
        markersize=8,
        label=label,
    )

ax.set_xlabel(r"$k$", fontsize=18)
ax.set_xticks(T)
ax.set_xticklabels(T_prime)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=14)
ax.set_ylabel(r"Sink$_1^\varepsilon$ (%)", fontsize=18)

# Explicit ".png" matches what matplotlib appends by default when no extension
# is given; create the output directory so savefig does not fail on a fresh
# checkout.
save_path = "imgs/exp_mactivation_attnsinks.png"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
fig.savefig(
    save_path,
    dpi=300,
    bbox_inches="tight",  # trim surrounding whitespace
)



