from selfie.gemma_util import *
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local checkpoint directory (named after google/gemma-2-9b-it).
model_path = "/root/autodl-tmp/autodl-tmp/model/models--google--gemma-2-9b-it"

# Load the weights in bfloat16; device_map="auto" lets transformers/accelerate
# decide where each part of the model is placed on the available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Word(s) whose representation is compared across the two sentences, and the
# matching position argument handed to get_token_indices_multiple.
target_words = ["key"]
target_positions = [1]


def _encode_and_extract(sentence):
    """Run one forward pass over *sentence* and extract the target-word hidden states.

    Tokenizes *sentence*, runs the model with hidden states and attentions
    enabled, locates ``target_words`` via ``get_token_indices_multiple``
    (printing the indices), and collects their per-layer hidden states.

    Returns:
        ``(inputs, output, token_indices, hidden_states)`` so the script keeps
        the same module-level names it had before this helper existed.
    """
    # Move inputs to model.device instead of a hard-coded "cuda": with
    # device_map="auto" the model may not live on the default CUDA device.
    encoded = tokenizer(
        sentence,
        return_tensors="pt",
        add_special_tokens=True,
        return_offsets_mapping=True,
    ).to(model.device)
    # Pure inference: without no_grad, autograd would keep every layer's
    # activations (on top of the requested hidden states/attentions) alive.
    with torch.no_grad():
        out = model(
            **encoded,
            return_dict=True,
            output_attentions=True,
            output_hidden_states=True,
            specific_head=False,
            causle_mask_tag=False,
        )
    indices = get_token_indices_multiple(sentence, target_words, tokenizer, target_positions)
    print(indices)
    states = get_hidden_states_by_layer(out, indices)
    return encoded, out, indices, states


text = "The key was rusty and no longer fit the lock."
# text = "The key was rusty, but opening the door to new possibilities."
# text = "This key was rusty and can't open the door."
inputs, output, token_indices, hidden_states = _encode_and_extract(text)
# plot_token_hidden_states(hidden_states,target_words[0],"/root/autodl-tmp/1")


text_2 = "The key was rusty, but it opens new possibilities."
# text_2 = "He took a quantum leap in his career."
inputs_2, output_2, token_indices_2, hidden_states_2 = _encode_and_extract(text_2)


# Interpretation mode shared by prompt construction and insert-info generation.
mode = "meaning_diff"

# Build the chat-style prompt for the first sentence's target token indices;
# returns the index info consumed below and the prompt ("sen").
unused_index, sen = build_words_with_prompts_modechat(
    token_indices,
    tokenizer=tokenizer,
    mode=mode,
)

# Combine the hidden states of both sentences into the model input
# (presumably the states to be inserted at the prompt placeholders — confirm
# against selfie.gemma_util).
test_input = generate_all_insert_infos_combined(
    hidden_states,
    unused_index,
    mode=mode,
    inputs_position=1,
    second_hidden=hidden_states_2,
)

# Run the interpretation; max_length=300 is presumably the generation cap.
result = begin_2_interpret(
    test_input,
    1,
    sen,
    tokenizer,
    model,
    max_length=300,
)


import os

# Append every interpretation to the results file, each wrapped in
# <begin>/<end> marker lines so multiple runs can share one file.
result_path = "/root/autodl-tmp/autodl-tmp/result/diff/interpretation_key_explain_it_from_my_first_impression_with_metaphor.txt"
# Create the directory up front: a missing directory would otherwise raise
# only after the (expensive) model run, discarding all results.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
# Explicit UTF-8 so non-ASCII model output does not depend on the locale.
with open(result_path, "a", encoding="utf-8") as f:
    for interpretation in tqdm(result):
        f.write("<begin>\n" + interpretation + "\n<end>\n")