import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import torch


# Checkpoint location passed to from_pretrained; tokenizer and model are both loaded from it.
model_path = "./llama3.2"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)

# Fixed sample sentence used for a single forward pass through the loaded model.
input_text = "The quick brown fox jumps over the lazy dog."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Nothing in this script back-propagates through the forward pass, so run it
# without autograd bookkeeping to avoid building and retaining the graph.
with torch.no_grad():
    outputs = model(input_ids)


# print(f"model:{model}")

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Produces the same values as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1):
    an input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When n_rep == 1 the input
    tensor itself is returned unchanged.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a singleton "copy" axis right after the head axis, broadcast it to
        # n_rep, then fold it back into the head axis so copies of a head are adjacent.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

def repeat_weights(W_k: torch.Tensor, n_rep: int = 4, head_dim: int = 1) -> torch.Tensor:
    """
    Repeat consecutive row blocks of a 2-D weight matrix ``n_rep`` times each.

    ``W_k`` has shape (rows, hidden_dim). Its rows are split into consecutive blocks of
    ``head_dim`` rows, and every block is emitted ``n_rep`` times back-to-back, giving a
    result of shape (rows * n_rep, hidden_dim).

    With the default ``head_dim=1`` every single row is repeated, i.e. the result equals
    torch.repeat_interleave(W_k, repeats=n_rep, dim=0). To mirror what ``repeat_kv`` does
    to projected key/value states (whole heads repeated), pass the per-head row count as
    ``head_dim``.

    Args:
        W_k: 2-D weight tensor of shape (rows, hidden_dim).
        n_rep: Number of copies of each block. When it is 1, ``W_k`` itself is returned.
        head_dim: Number of consecutive rows that form one block; must divide ``rows``.

    Returns:
        Tensor of shape (rows * n_rep, hidden_dim).

    Raises:
        ValueError: If ``head_dim`` is not a positive divisor of the row count.
    """
    rows, hidden_dim = W_k.shape
    if n_rep == 1:
        return W_k
    if head_dim < 1 or rows % head_dim != 0:
        raise ValueError(
            f"head_dim={head_dim} must be a positive divisor of the row count {rows}"
        )
    num_blocks = rows // head_dim
    # (blocks, 1, head_dim, hidden) -> broadcast the singleton axis to n_rep copies,
    # then flatten so the copies of one block sit next to each other.
    blocks = W_k.reshape(num_blocks, 1, head_dim, hidden_dim)
    blocks = blocks.expand(num_blocks, n_rep, head_dim, hidden_dim)
    return blocks.reshape(rows * n_rep, hidden_dim)

def gen_real_W_q_W_k():
    """
    Export the first decoder layer's query and key projection weights as .npy files.

    Reads ``q_proj.weight`` and ``k_proj.weight`` from ``model.layers[0].self_attn``,
    expands the key weight so it has as many rows as the query weight, and writes the
    two matrices to "W_q.npy" and "W_k.npy" in the current working directory.

    The key weight is expanded one key/value head at a time (each head's block of rows
    is repeated as a unit), which matches how ``repeat_kv`` repeats projected key
    states. The repeat factor is derived from the two weight shapes rather than fixed.

    Raises:
        ValueError: If the query/key row counts or the configured number of key/value
            heads do not divide evenly.
    """
    self_attention = model.layers[0].self_attn

    # Detach up front so the reshapes below do not record autograd history.
    W_q = self_attention.q_proj.weight.detach()  # original note: (2048, 2048)
    W_k = self_attention.k_proj.weight.detach()  # original note: (512, 2048)

    q_rows, hidden_dim = W_q.shape
    k_rows = W_k.shape[0]
    # assumes a Llama-style config exposing num_key_value_heads, and that k_proj rows
    # are laid out head-major (all rows of head 0, then head 1, ...) -- TODO confirm
    num_kv_heads = model.config.num_key_value_heads
    if q_rows % k_rows != 0 or k_rows % num_kv_heads != 0:
        raise ValueError(
            f"incompatible projection shapes: q rows={q_rows}, k rows={k_rows}, "
            f"num_key_value_heads={num_kv_heads}"
        )
    n_rep = q_rows // k_rows

    # Flatten each key/value head into a single row so that the row-wise repetition
    # done by repeat_weights duplicates whole heads, then restore the 2-D layout.
    W_k_per_head = W_k.reshape(num_kv_heads, -1)
    W_k = repeat_weights(W_k_per_head, n_rep).reshape(q_rows, hidden_dim)

    W_q_numpy = W_q.cpu().numpy()
    W_k_numpy = W_k.cpu().numpy()

    # save W_q and W_k
    np.save("W_q.npy", W_q_numpy)
    np.save("W_k.npy", W_k_numpy)



# Script entry: print a banner, then write the first layer's W_q / W_k to .npy files.
print("get prarameter test:")
gen_real_W_q_W_k()
