import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# Checkpoints this script has been pointed at. Only the last entry is used;
# reorder the tuple (or change the index) to plot a different model.
_MODEL_CANDIDATES = ("Llama-2-7b-hf", "Meta-Llama-3-8B", "Mistral-7B-v0.1")
model_name = _MODEL_CANDIDATES[-1]

# Intervention modes; each one has its own saved activation-magnitude file.
mode_1, mode_2, mode_3, mode_4 = (
    "set_swap",
    "set_zero",
    "set_topk_input_zero",
    "set_x0_one_layer",
)

# Every result file lives under the selected model's directory and is named
# after its mode.
_RESULT_TEMPLATE = (
    "Exps/exp_mactivation_attnsink/{model}/{mode}_activation_magnitude.npy"
)
save_path_1, save_path_2, save_path_3, save_path_4 = (
    _RESULT_TEMPLATE.format(model=model_name, mode=mode)
    for mode in (mode_1, mode_2, mode_3, mode_4)
)


def _load_norm_profile(npy_path):
    """Load a saved array and reduce it to a profile of L2 norms.

    The array at ``npy_path`` is indexed as ``[0, 0, :, 0, :]`` (so it must
    have at least five axes): index 0 is taken on the first, second and
    fourth axes, while the third and last axes are kept. The L2 norm is then
    taken over the last axis, leaving one value per entry of the third axis.

    NOTE(review): the third axis is presumably the layer index (the figure
    below plots this profile against "Layer") -- confirm against the code
    that wrote these files.
    """
    activations = torch.from_numpy(np.load(npy_path))
    selected = activations[0, 0, :, 0, :]
    return selected.norm(p=2, dim=-1)


# One profile per intervention mode, loaded in the same order as the paths.
data_1, data_2, data_3, data_4 = (
    _load_norm_profile(path)
    for path in (save_path_1, save_path_2, save_path_3, save_path_4)
)

# Panel title -> profile; insertion order determines the left-to-right
# order of the subplots.
data = dict(
    swap=data_1,
    entire_zero=data_2,
    topk_zero=data_3,
    scaling=data_4,
)

# ---------------------------------------------------------------------------
# Figure: one panel per entry of `data`, each plotting that entry's L2-norm
# profile against its index (labelled "Layer" on the x axis).
# ---------------------------------------------------------------------------

# Derive the x axis from the loaded profiles instead of assuming 32 entries,
# so a checkpoint with a different depth does not make `ax.plot` fail on a
# length mismatch.
num_layers = max(len(series) for series in data.values())
T = np.arange(num_layers)
T_prime = np.arange(0, num_layers, 4)  # tick every 4th index to keep labels readable

# One column per profile; squeeze=False keeps `axes` indexable even when
# `data` holds a single entry.
fig, axes = plt.subplots(
    1, len(data), figsize=(18, 5), sharey=False, squeeze=False
)
axes = axes[0]

# The title follows the selected checkpoint so the figure cannot be
# mislabelled when `model_name` changes. Short display names can be added
# here; unknown checkpoints fall back to their full name.
display_names = {"Mistral-7B-v0.1": "Mistral-7B"}
fig.suptitle(display_names.get(model_name, model_name), fontsize=40)

# Marker per panel (currently identical; kept as a table so individual
# panels can be restyled without touching the loop).
markers = {
    "swap": "o",
    "entire_zero": "o",
    "topk_zero": "o",
    "scaling": "o",
}

for ax, (mode, x) in zip(axes, data.items()):
    # Slice the shared axis so profiles of differing length still plot.
    ax.plot(
        T[: len(x)],
        x,
        marker=markers[mode],
        linewidth=2,
        markersize=8,
    )

    ax.set_title(mode, fontsize=18)
    ax.set_xlabel("Layer", fontsize=18)
    ax.set_xticks(T_prime)
    ax.set_xticklabels(T_prime)
    ax.grid(True, alpha=0.3)

# Only the left-most panel carries the y label.
axes[0].set_ylabel(r"$L_2$-norm", fontsize=18)

# Leave the top 5% of the canvas free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])

# No extension is given, so matplotlib chooses the output format from
# rcParams["savefig.format"].
# NOTE(review): the path does not depend on `model_name`, so runs for
# different checkpoints overwrite the same image -- confirm this is intended.
save_path = "imgs/exp_mactivation_attnsinks"
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # savefig does not create directories
plt.savefig(
    save_path,
    dpi=300,
    bbox_inches="tight",
)



