import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D

# --- Experiment configuration ------------------------------------------------
# Directory names of the two models whose saved arrays are read below.
model_name_1 = "Llama-2-7b-hf"
model_name_2 = "Mistral-7B-v0.1"

# Experiment variant; it selects which "*_activation_magnitude.npy" file is read.
mode = "set_swap"

# --- Input files for model_name_1 --------------------------------------------
save_path_1 = f"Exps/exp_mactivation_attnsink/{model_name_1}/{mode}_activation_magnitude.npy"
save_path_2 = f"results/{model_name_1}/rms2_in_bos.npy"
save_path_3 = f"results/{model_name_1}/ffn_out_bos.npy"

# --- Input files for model_name_2 --------------------------------------------
save_path_4 = f"Exps/exp_mactivation_attnsink/{model_name_2}/{mode}_activation_magnitude.npy"
save_path_5 = f"results/{model_name_2}/rms2_in_bos.npy"
save_path_6 = f"results/{model_name_2}/ffn_out_bos.npy"


def _load_as_tensor(npy_path):
    """Read the ``.npy`` file at *npy_path* and return it as a torch tensor."""
    return torch.from_numpy(np.load(npy_path))


# All six files are read eagerly, in path order; a missing file stops the
# script here.
# NOTE(review): data_4..data_6 (the model_name_2 arrays) are loaded but not
# referenced by the plotting code visible in this file.
data_1, data_2, data_3, data_4, data_5, data_6 = map(
    _load_as_tensor,
    (save_path_1, save_path_2, save_path_3, save_path_4, save_path_5, save_path_6),
)

# Quick sanity check of what was loaded: one shape per line.
print(data_1.shape, data_2.shape, sep="\n")


def _mean_l2_norm(acts):
    """Take the L2 norm over the last axis, then average over the first axis."""
    return acts.norm(p=2, dim=-1).mean(dim=0)


# Curves to plot, grouped by model display name and keyed by legend label.
# Index 0 on the sliced axis presumably selects the first (BOS) token, going by
# the "*_bos" file names - TODO confirm against the script that wrote the arrays.
data = {
    "Llama2-7B": {
        # Baseline: elementwise sum of the two "results/" arrays.
        "Before_swap": _mean_l2_norm((data_2 + data_3)[:, :, 0]),
        # One curve per leading index of data_1; label "k=<n>" reads slice n-1.
        **{f"k={k}": _mean_l2_norm(data_1[k - 1, :, :, 0]) for k in range(1, 6)},
    }
}


# Plot every curve of one model on a single axis and save the figure to disk.
model = "Llama2-7B"
curves = data[model]  # reuse `model` so the title and the plotted data cannot drift apart

# Derive the x-axis from the data instead of hard-coding 32 blocks, so arrays
# with a different number of blocks still plot.
num_blocks = len(next(iter(curves.values())))
T = np.arange(1, num_blocks + 1)  # 1-based block indices
T_prime = T  # tick labels are the block indices themselves

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# One line per configuration; the dict key is the legend label.
for label, y in curves.items():
    ax.plot(T, y, linewidth=2, label=label)

ax.set_title(model, fontsize=18)
ax.set_xlabel(r"$Block$", fontsize=18)
ax.set_ylabel(r"$L_2$-norm", fontsize=18)
ax.set_xticks(T)
ax.set_xticklabels(T_prime)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=20)

plt.tight_layout()

# No extension is given, so matplotlib appends its default output format
# (rcParams["savefig.format"]) to the file name.
save_path = "imgs/exp_mactivation_attnsinks"
# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, dpi=300, bbox_inches="tight")








