import matplotlib.pyplot as plt
import seaborn as sns
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Local directory the tokenizer and model weights are loaded from.
# NOTE(review): the final path component looks like a Hugging Face cache root
# ("models--<org>--<name>"); confirm it holds the model files directly rather
# than under a snapshots/ subdirectory.
model_path = "/root/autodl-tmp/autodl-tmp/model/models--google--gemma-2-9b-it"

# Both loads read from the same checkpoint directory and are independent of
# each other. device_map="auto" leaves device placement to the library.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
)

# Input sentence pair; each sentence is fed to the model separately.
extended_text = [
    "She picked up a small rock from the hiking trail as a keepsake.",
    "The customer went to the bank early in the morning to watch the water flow gently along the riverbed.",
]

# Last-entry hidden states, one tensor per sentence, in input order.
hidden_states_list = []
# Head-averaged attention maps (NumPy arrays), one per sentence, kept so that
# all sentences can be plotted together after the forward passes.
attention_maps = []

# Inference only: with autograd disabled no graph is recorded, so the tensors
# stored in hidden_states_list do not keep intermediate activations alive.
with torch.no_grad():
    for text in extended_text:
        # Tokenize and move the inputs to the device the model reports,
        # rather than a hard-coded "cuda", since placement was delegated to
        # device_map="auto".
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Forward pass returning per-layer hidden states and attention maps.
        output = model(
            **inputs,
            return_dict=True,
            output_hidden_states=True,
            output_attentions=True,
        )

        # Index -1 selects the LAST entry of the hidden-states tuple (the
        # final layer's output). Change the index to inspect another layer.
        hidden_states = output.hidden_states[-1]
        hidden_states_list.append(hidden_states)

        # Full tuple of attention tensors (one entry per layer).
        attention_weights = output.attentions

        # Entry 1 of the attention tuple, first batch item, all heads,
        # restricted to the first 10 query and key positions (shorter inputs
        # simply yield a smaller slice).
        # NOTE(review): tuple index 1 is the second attention layer when
        # counting from 0; the "Layer 1" label below assumes 0-based naming.
        attention_weights_layer_1 = output.attentions[1][0, :, :10, :10]

        # Average over heads. float() keeps the NumPy conversion valid if the
        # model is ever loaded in a reduced-precision dtype.
        attention_maps.append(
            attention_weights_layer_1.mean(dim=0).float().cpu().numpy()
        )

# Plot once, after all sentences are processed: one panel per sentence on a
# dedicated figure. Drawing each heatmap on its own axes prevents heatmaps and
# colorbars from stacking on a shared implicit figure, and a single save keeps
# one sentence's plot from overwriting another's.
if attention_maps:
    panel_count = len(attention_maps)
    fig, axes = plt.subplots(
        1, panel_count, figsize=(6 * panel_count, 5), squeeze=False
    )
    panels = zip(axes[0], attention_maps)
    for sentence_idx, (ax, attention_map) in enumerate(panels, start=1):
        sns.heatmap(attention_map, cmap="viridis", ax=ax)
        ax.set_title(f"Sentence {sentence_idx}")
        ax.set_xlabel("Key Tokens")
        ax.set_ylabel("Query Tokens")
    fig.suptitle("Attention Weights (Layer 1, First 10 Tokens)")
    fig.tight_layout()
    fig.savefig('/root/autodl-tmp/attention_weights_layer_1.png')
    # Release the figure so repeated runs in one session do not accumulate
    # open figures.
    plt.close(fig)



# # Check that hidden states were collected for both sentences
# if len(hidden_states_list) == 2:
#     # Extract the hidden states of the leading tokens
#     # NOTE(review): the slices below take 13 tokens, while the printed
#     # message and the tail slices (10:) assume 10 -- confirm the cutoff.
#     hidden_1 = hidden_states_list[0][:, :13, :]  # first sentence
#     hidden_2 = hidden_states_list[1][:, :13, :]  # second sentence

#     # Compute the mean squared error (MSE)
#     mse = torch.mean((hidden_1 - hidden_2) ** 2)
#     print(f"Mean Squared Error (MSE) for the first 10 tokens: {mse.item()}")

#     # Extract the hidden states of the remaining tokens
#     hidden_1_tail = hidden_states_list[0][:, 10:, :]
#     hidden_2_tail = hidden_states_list[1][:, 10:, :]   
#     # Truncate both tails to the length of the shorter sentence
#     min_length = min(hidden_1_tail.size(1), hidden_2_tail.size(1))
#     hidden_1_tail = hidden_1_tail[:, :min_length, :]
#     hidden_2_tail = hidden_2_tail[:, :min_length, :]

#     # Compute the MSE over the remaining tokens
#     mse_tail = torch.mean((hidden_1_tail - hidden_2_tail) ** 2)
#     print(f"Mean Squared Error (MSE) for the remaining tokens: {mse_tail.item()}")
# else:
#     print("Error: Hidden states for both sentences were not computed.")
