from transformers import AutoTokenizer

from modeling_llama_scale import LlamaForCausalLM

# Checkpoint identifier; both the tokenizer and the weights are loaded from it.
model_path = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# LlamaForCausalLM here is the class imported from modeling_llama_scale,
# not the stock transformers one.
model = LlamaForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",  # request the FlashAttention-2 attention path
    torch_dtype="auto",  # let the loader choose the dtype
    device_map="auto",  # let the loader choose device placement
    trust_remote_code=True,
)
# Switch the module to evaluation mode; this script only runs generation.
model.eval()

# Hidden-state scaling configuration, attached to the model config.
# NOTE(review): presumably read by the model code in modeling_llama_scale;
# the key semantics below are taken from the original inline comments.
model.config.hidden_scale_config = {
    # Layers the scaling is applied to (indices 10 through 25 inclusive).
    "layers": range(10, 26),
    # Hidden dimensions the scaling is applied to.
    "dims": [2393],
    # The scaling factor.
    "factor": -1,
    # Skip the first n tokens when scaling hidden states.
    "skip_first": 0,
    # Number of tokens whose attention weights are recomputed.
    "last_recompute_tokens": 1,
    # Whether to change the value states. If False, only the query and key
    # states are modified.
    "change_value": False,
}

# Conversation to send: a system prompt that sets the persona, followed by
# a single user turn.
messages = [
    {
        "role": "system",
        "content": "You are a pirate chatbot who always responds in pirate speak!",
    },
    {
        "role": "user",
        "content": "Who are you?",
    },
]

# Render the conversation through the tokenizer's chat template, asking for
# the generation prompt to be appended and for a PyTorch tensor back.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
# Place the prompt on the device the model reports.
input_ids = input_ids.to(model.device)


# Sample up to 256 new tokens with temperature / nucleus (top-p) sampling,
# with the generation cache enabled.
outputs = model.generate(
    input_ids,
    use_cache=True,
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    temperature=0.6,
)

# Take the first returned sequence and drop its leading prompt-length
# positions so that only the newly generated tokens are decoded.
response = outputs[0][input_ids.size(-1):]
print(
    tokenizer.decode(
        response,
        skip_special_tokens=True,
    )
)

