import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.gridspec import GridSpec

# Sub-directory names under ./results/ (also used as output-file prefixes)
# for the two models whose hidden states are plotted.
model_name_1, model_name_2 = "Llama-2-7b-hf", "Meta-Llama-3-8B"

def draw(model_name, *, results_dir="./results", out_dir="./imgs"):
    """Plot selected per-layer vectors and per-layer L2 norms for one model.

    Loads ``rms2_in_bos.npy`` and ``ffn_out_bos.npy`` from
    ``{results_dir}/{model_name}/`` and sums them element-wise. It then
    draws a two-panel figure:

    * left (3-D): the vector ``sum[0, layer, 0, :]`` for a few hand-picked
      layers, one line per layer;
    * right (2-D): the L2 norm of ``sum[0, layer, 0, :]`` for every layer.

    The figure is saved to ``{out_dir}/{model_name}_basic.png`` and closed.

    Args:
        model_name: Sub-directory under ``results_dir``; also the output file
            prefix. ``"Llama-2-7b-hf"`` selects the Llama-2 x-tick positions
            and title. Any other name gets the Llama-3 ones.
        results_dir: Directory holding one folder per model.
        out_dir: Directory the PNG is written to; created if it is missing.
    """
    import os

    x_3 = np.load(f"{results_dir}/{model_name}/rms2_in_bos.npy")
    x_5 = np.load(f"{results_dir}/{model_name}/ffn_out_bos.npy")
    # NOTE(review): presumably the residual stream after the FFN block
    # (RMSNorm-2 input + FFN output); confirm against the dumping code.
    x_6 = x_3 + x_5
    # The summed array is indexed as 4-D [a, layer, b, dim]. Only a=0 and b=0
    # are used. Presumably (batch, layer, token, hidden) with token 0 = BOS,
    # going by the file names -- TODO confirm.
    states = x_6[0, :, 0, :]  # (num_layers, dim)

    # 0-based layer positions shown in the 3-D panel; tick labels are 1-based.
    idx = [0, 1, 2, 9, 19, 31]
    layer_id = [i + 1 for i in idx]
    X = states[idx]

    fig = plt.figure(figsize=(12, 6), dpi=400)
    try:
        # One row, two equally wide panels (the single row gives both the same height).
        gs = GridSpec(1, 2, figure=fig, width_ratios=[1, 1], height_ratios=[1], wspace=0.5)

        # --- Left panel: raw vectors of the selected layers in 3-D ---
        ax3d = fig.add_subplot(gs[0, 0], projection="3d")
        dim_idx = np.arange(X.shape[1])
        line_color = "#1f77b4"
        # Each selected layer is drawn at y = its row index so the y ticks below line up.
        for row, values in enumerate(X):
            ax3d.plot(
                dim_idx,
                np.full_like(dim_idx, row),
                values,
                linewidth=1.5,
                alpha=0.95,
                color=line_color,
            )

        ax3d.set_ylabel("Layer index", fontsize=11, labelpad=8)

        # Hand-picked hidden dimensions to label on the x axis (model-specific).
        special_positions = (
            [1415, 2533] if model_name == "Llama-2-7b-hf" else [788, 1384, 4062]
        )
        ax3d.set_xticks(special_positions)
        ax3d.set_xticklabels([str(p) for p in special_positions])

        ax3d.set_yticks(np.arange(len(layer_id)))
        ax3d.set_yticklabels([str(i) for i in layer_id])

        ax3d.grid(True, linestyle="--", alpha=0.3)
        ax3d.tick_params(axis="both", labelsize=9)
        ax3d.view_init(elev=25, azim=-60)

        # --- Right panel: L2 norm of every layer ---
        ax2d = fig.add_subplot(gs[0, 1])
        norms = np.linalg.norm(states, axis=1)
        # Derive the x axis from the data rather than assuming exactly 32 layers.
        layers = np.arange(1, norms.shape[0] + 1)
        ax2d.plot(layers, norms, color="#ff4500", linewidth=2, marker="o", markersize=4)
        ax2d.set_xlabel("Layer index")
        ax2d.set_ylabel("$L_2$-norm")
        ax2d.grid(True, linestyle="--", alpha=0.4)
        # No ax2d.legend(): the single line has no label, so legend() only
        # emitted a "No artists with labels" warning.
        ax2d.set_facecolor("#ffffff")

        title = "Llama2-7B" if model_name == "Llama-2-7b-hf" else "Llama3-8B"
        fig.suptitle(title, fontsize=32, y=1)

        plt.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}/{model_name}_basic.png", dpi=400, bbox_inches="tight")
    finally:
        # Always release the figure, even if loading/plotting/saving failed midway.
        plt.close(fig)

if __name__ == "__main__":
    # Render and save the figure for each model in turn (Llama-2 first, then Llama-3).
    for name in (model_name_1, model_name_2):
        draw(name)

