from selfie.gemma_util import *
import torch
import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer

# Local snapshot directory of the model (presumably the HF-cache layout of
# google/gemma-2-9b-it, judging by the path); adjust for your machine.
model_path = "/root/autodl-tmp/autodl-tmp/model/models--google--gemma-2-9b-it"

# Tokenizer and weights come from the same directory. Weights are loaded in
# bfloat16 and placed on the available devices via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

prompt = "This key was rusty but it opens new possibilities."

# Tokenize the prompt. Character offsets are requested so they can be inspected,
# but they are popped off the encoding so that **inputs below does not forward an
# `offset_mapping` keyword to the model's forward(), which is not a model input.
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    add_special_tokens=True,
    return_offsets_mapping=True,
).to(model.device)  # follow wherever device_map="auto" put the model
offset_mapping = inputs.pop("offset_mapping")

# get_token_indices comes from selfie.gemma_util; judging by the print below it
# yields one collection of token indices per word of the prompt.
token_indices_list = get_token_indices(prompt, tokenizer)
for word_idx, indices in enumerate(token_indices_list):
    print(f"Word {word_idx + 1}: Token indices {indices}")

# One inference-only forward pass. no_grad avoids keeping autograd state for
# every attention map / hidden state of a 9B model.
# NOTE(review): `specific_head` is not a stock transformers argument; presumably
# it is consumed by a patched Gemma-2 forward that also fills `output.single_att`
# (used later) — confirm against that patch.
with torch.no_grad():
    output = model(
        **inputs,
        return_dict=True,
        output_attentions=True,
        output_hidden_states=True,
        specific_head=True,
    )
# Response prefix (a Gemma chat "model" turn) handed to
# build_sentence_with_prompts_modechat as prompt2. Other prefixes tried in
# earlier runs, kept for reference:
#   "<start_of_turn>model\nAs Sam's perspective, the rubber duck in its initial location by the end of the story:"
#   "<start_of_turn>model\nSure, I'll repeat your message:"
#   '<start_of_turn>model\nSure, I\'ll explain the meaning about the "key" in this sentence:'
#   "<start_of_turn>model\nSure, I'll summary the concept about the key in your message,"
#   '<start_of_turn>model\nSure, I\'ll summary your message around the "key":'
# NOTE(review): the wording (including "summary" used as a verb) is kept
# verbatim; presumably changing it would make results incomparable with
# earlier runs.
prompt2 = "<start_of_turn>model\nSure, I'll summary the meaning of your message:"

# Helper from selfie.gemma_util. It returns a pair. `sen` is passed on to
# begin_2_interpret later. `unused_index` is passed to
# hidden_states_to_insert_infos — presumably the placeholder positions where
# hidden states get inserted (confirm in gemma_util).
unused_index, sen = build_sentence_with_prompts_modechat(
    token_indices_list,
    tokenizer=tokenizer,
    prompt2=prompt2,
)
print(unused_index)
import os

# Destination for one interpretation file per layer.
OUTPUT_DIR = "/root/autodl-tmp/autodl-tmp/result/key_word_m_summary_meaning_300_0"
# open() does not create missing parent directories; create the directory up
# front so the first write cannot fail with FileNotFoundError.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# output.single_att is indexed by layer. Each layer entry is iterated element by
# element (presumably one element per attention head — confirm against the
# patched model).
for layer_idx, layer_heads in tqdm.tqdm(
    enumerate(output.single_att), total=len(output.single_att)
):
    all_insert_infos = [
        hidden_states_to_insert_infos(head_states, unused_index, inputs_position=1)
        for head_states in layer_heads
    ]

    print("---" * 10)
    print(f"the layer {layer_idx} is coming")
    result = begin_2_interpret(all_insert_infos, 1, sen, tokenizer, model, max_length=300)
    print("---" * 10)

    # Save this layer's interpretations locally, each wrapped in <begin>/<end>
    # markers. A distinct loop name keeps layer_idx intact, unlike the old inner
    # `for i in result`, which shadowed the layer index. An explicit UTF-8
    # encoding keeps non-ASCII model output from depending on the locale.
    out_path = os.path.join(OUTPUT_DIR, f"interpretation_{layer_idx}.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        for interpretation in result:
            f.write("<begin>\n" + interpretation + "\n<end>\n")
