import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

def Find_vector(x_position, layer_id=1, token_id_1=1, token_id_2=0, model_name = "Llama-2-7b-hf"):
  """Compute BOS vs. no-BOS activation difference vectors and save them to disk.

  Loads the activation dumps for the requested position from
  ``results/{model_name}``, in a "bos" and a "no_bos" variant, then computes

  * ``delta1``: ``no_bos[0, layer_id, token_id_2] - bos[0, layer_id, token_id_1]``
    (first entry along axis 0 only), and
  * ``delta2``: the same difference taken for every entry along axis 0 and
    averaged over that axis.

  ``delta1`` is written to ``vec_random.npy`` and ``delta2`` to
  ``vec_average.npy`` inside ``results/{model_name}``.

  Args:
    x_position: Which dump to use. ``"x_0"`` -> ``rms1_in``, ``"x_1"`` ->
      ``rms1_out``, ``"x_3"`` -> ``rms2_in``, ``"x_4"`` -> ``rms2_out``,
      ``"x_5"`` -> ``ffn_out``. Any other value selects the element-wise sum
      of ``rms2_in`` and ``ffn_out``.
    layer_id: Index used along axis 1 of the loaded arrays.
    token_id_1: Index used along axis 2 of the "bos" array.
    token_id_2: Index used along axis 2 of the "no_bos" array.
    model_name: Sub-directory of ``results/`` that holds the dumps and that
      receives the output vectors.

  Returns:
    None. The results are only written to disk.

  Raises:
    FileNotFoundError: If a required ``.npy`` file (or the output directory)
      does not exist.
    IndexError: If one of the indices is out of range for the loaded arrays.
  """
  # NOTE(review): the arrays are presumably laid out as
  # (sample, layer, token, hidden) -- inferred from the index names; confirm.
  save_dir = f"results/{model_name}"

  # File-name stems of the dump(s) that make up each supported position.
  stems_by_position = {
    "x_0": ("rms1_in",),
    "x_1": ("rms1_out",),
    "x_3": ("rms2_in",),
    "x_4": ("rms2_out",),
    "x_5": ("ffn_out",),
  }
  # Unrecognised positions keep the historical catch-all behaviour:
  # the sum of the rms2_in and ffn_out dumps.
  stems = stems_by_position.get(x_position, ("rms2_in", "ffn_out"))

  def _load_summed(variant):
    # Load every dump of this position for one variant ("bos" / "no_bos")
    # and add them element-wise.
    total = None
    for stem in stems:
      tensor = torch.from_numpy(np.load(f"{save_dir}/{stem}_{variant}.npy"))
      total = tensor if total is None else total + tensor
    return total

  bos = _load_summed("bos")
  no_bos = _load_summed("no_bos")

  # Difference for the first entry only, and the mean difference over axis 0.
  delta1 = no_bos[0, layer_id, token_id_2] - bos[0, layer_id, token_id_1]
  delta2 = (no_bos[:, layer_id, token_id_2] - bos[:, layer_id, token_id_1]).mean(dim=0)

  if x_position == "x_0":
    # Shape print kept from the original x_0 branch.
    print(bos.shape)

  # Write next to the inputs so the outputs follow `model_name` instead of
  # always landing in the Llama-2-7b-hf directory.
  vec_random_path = f"{save_dir}/vec_random.npy"
  vec_average_path = f"{save_dir}/vec_average.npy"

  np.save(vec_random_path, delta1.detach().numpy())
  np.save(vec_average_path, delta2.detach().numpy())

if __name__ == "__main__":
  # Script entry point: build the difference vectors for the "x_0" dump
  # at layer 1, comparing "bos" token index 1 with "no_bos" token index 0.
  x_position, layer_id = "x_0", 1
  token_id_1, token_id_2 = 1, 0
  Find_vector(
    x_position,
    layer_id=layer_id,
    token_id_1=token_id_1,
    token_id_2=token_id_2,
  )
  


# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# ffn = model.model.layers[layer_id].mlp
# token_id_3 = 1

# alpha_list = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]
# ans = []
# for alpha in alpha_list:
#   input = x_4_bos[0,layer_id,token_id_3] + alpha * delta
#   output = ffn(input).detach().cpu()
#   ans.append(output.norm(p=2,dim=-1).detach())


# x = range(len(ans))
# plt.plot(x, ans, label="L2norm")

# plt.legend()
# plt.xlabel("alpha")
# plt.ylabel("L2norm")
# plt.title("")
# plt.show()