from selfie.gemma_util import *
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

# Location of the Gemma-2 checkpoint on local disk; edit to point at your own copy.
# NOTE(review): a folder named "models--<org>--<name>" has the layout of a Hugging Face
# Hub cache entry, whose loadable files usually sit under snapshots/<revision>/ —
# confirm this directory itself contains config.json and the weight files.
model_path = "/root/autodl-tmp/autodl-tmp/model/models--google--gemma-2-9b-it"

# Tokenizer and model are loaded from the same checkpoint directory. The model is
# loaded in bfloat16 and device_map="auto" lets transformers/accelerate choose
# where each part of it is placed.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

text = "Lawyers are like foxes"

# Tokenize the sentence. add_special_tokens=True asks the tokenizer to add its special
# tokens; for Gemma this prepends the <bos> token to the sequence.
# Offsets are not requested: nothing below reads them, and every key in `inputs` is
# forwarded into the model call, where an extra `offset_mapping` tensor does not belong.
# Inputs go to the model's own device rather than a hard-coded "cuda", so this stays
# correct whatever device_map="auto" chose.
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(model.device)

# One inference-only forward pass that records attentions and all hidden states.
# no_grad avoids building an autograd graph for activations nobody back-propagates
# through here. NOTE(review): if a gemma_util helper needs gradients through these
# hidden states, drop the no_grad.
# Calling the module (rather than .forward) also runs any registered nn.Module hooks.
# specific_head is not a stock transformers argument; presumably it is consumed by a
# patched Gemma-2 forward — TODO confirm.
with torch.no_grad():
    output = model(
        **inputs,
        return_dict=True,
        output_attentions=True,
        output_hidden_states=True,
        specific_head=True,
    )

# Resolve which token(s) correspond to each target word.
# NOTE(review): whether target_positions is an occurrence index or a token offset is
# defined inside get_token_indices_multiple — confirm there.
target_words = ["like"]
target_positions = [1]
token_indices = get_token_indices_multiple(text, target_words, tokenizer, target_positions)
print(token_indices)

# Collect the selected tokens' representations at every 5th layer: 0, 5, ..., 40.
hidden_states = get_hidden_states_by_layer(output, token_indices, list(range(0, 41, 5)))
# Interpretation mode shared by the prompt builder and the insertion-info builder;
# "diff" is one of the modes understood by the gemma_util helpers.
mode = "diff"

# Build the interpretation prompt `sen` plus the placeholder index/indices
# (`unused_index`) where the captured hidden states will later be inserted.
unused_index, sen = build_words_with_prompts_modechat(token_indices, tokenizer=tokenizer, mode=mode)

# NOTE(review): the meaning of inputs_position is defined by
# generate_all_insert_infos_combined — confirm before changing it.
INPUTS_POSITION = 3
# Upper bound on generated length passed through to begin_2_interpret.
MAX_LENGTH = 600

# Pair each captured hidden state with the placeholder slot it should replace.
test_input = generate_all_insert_infos_combined(
    hidden_states, unused_index, mode=mode, inputs_position=INPUTS_POSITION
)

# Run the interpretation pass. NOTE(review): the positional `1` is forwarded unchanged
# from the original script; its meaning is defined by begin_2_interpret.
begin_2_interpret(test_input, 1, sen, tokenizer, model, max_length=MAX_LENGTH)