import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# Model directory names under the experiment folder.
model_name_1, model_name_2, model_name_3 = (
    "Llama-2-7b-hf",
    "Meta-Llama-3-8B",
    "Mistral-7B-v0.1",
)

# Intervention modes; each one names a ``<mode>.npy`` result file per model.
mode_1, mode_2 = "set_topk_input_zero", "set_topk_output_zero"

# One result file per (model, mode) pair, ordered model-major:
# paths 1-2 -> model 1, paths 3-4 -> model 2, paths 5-6 -> model 3,
# with mode 1 before mode 2 inside each pair.
(
    save_path_1,
    save_path_2,
    save_path_3,
    save_path_4,
    save_path_5,
    save_path_6,
) = [
    f"Exps/exp_mactivation_attnsink/{model_dir}/{mode_tag}.npy"
    for model_dir in (model_name_1, model_name_2, model_name_3)
    for mode_tag in (mode_1, mode_2)
]

# Load the arrays in the same order, so data_i always comes from save_path_i.
data_1, data_2, data_3, data_4, data_5, data_6 = (
    np.load(npy_file)
    for npy_file in (
        save_path_1,
        save_path_2,
        save_path_3,
        save_path_4,
        save_path_5,
        save_path_6,
    )
)



# Curves to draw, keyed by panel title and then by legend label.
# For every model, "Before LN" is taken from the ``set_topk_input_zero`` array
# and "After LN" from the ``set_topk_output_zero`` array. Only column 0 of each
# array is plotted (NOTE(review): meaning of that column is not visible here).
# Dict order fixes the left-to-right panel order of the figure.
data = {
    panel_title: {
        "Before LN": input_zeroed[:, 0],
        "After LN": output_zeroed[:, 0],
    }
    for panel_title, input_zeroed, output_zeroed in (
        ("Mistral-7B", data_5, data_6),
        ("Llama2-7B", data_1, data_2),
        ("Llama3-8B", data_3, data_4),
    )
}

# Tick labels shown on the x axis, and the x coordinates they are drawn at.
# The coordinates are not proportional to the labels (4096 sits at 4.5),
# presumably to keep the last two ticks from overlapping.
T_prime = np.array([1000, 2000, 3000, 4000, 4096])
T = np.array([1, 2, 3, 4, 4.5])

# Marker glyph per legend entry.
markers = {"Before LN": "o", "After LN": "*"}

label_size = 18

# One panel per model, side by side, sharing a single y axis.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 5), sharey=True)

for panel, model_title in zip(axes, data):
    # Draw both curves of this model against the same x positions.
    for curve_name, sink_values in data[model_title].items():
        panel.plot(
            T,
            sink_values,
            label=curve_name,
            marker=markers[curve_name],
            markersize=8,
            linewidth=2,
        )

    panel.set_title(model_title, fontsize=label_size)
    panel.set_xlabel(r"$k$", fontsize=label_size)

    # Ticks must be placed before they are relabelled.
    panel.set_xticks(T)
    panel.set_xticklabels(T_prime)

    panel.legend(fontsize=14)
    panel.grid(True, alpha=0.3)

# The y axis is shared, so only the leftmost panel carries the label.
axes[0].set_ylabel(r"Sink$_1^\varepsilon$ (%)", fontsize=label_size)

# plt.tight_layout()

# Destination of the rendered figure. The name has no extension, so matplotlib
# picks the output format from rcParams["savefig.format"] (PNG by default).
save_path = "imgs/exp_mactivation_attnsinks"

# savefig does not create missing parent directories; make sure the target
# folder exists so the script does not fail after all the plotting work.
output_dir = os.path.dirname(save_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

plt.savefig(
    save_path,
    dpi=300,               # print-quality resolution
    bbox_inches="tight",   # crop surrounding whitespace
)



