import torch
from tqdm import tqdm
import pandas as pd

from transformers import AutoTokenizer

from modeling_llama_scale import LlamaForCausalLM

def format_chat_prompt(input, model_path, tokenizer):
    """Wrap a single user message in the chat prompt format of the target model.

    The format is selected by case-insensitive substring matching on
    ``model_path``:

    - ``"mpt-7b-8k-instruct"``: a fixed instruction/response text template.
    - ``"longchat"`` or ``"vicuna"``: FastChat's ``"vicuna"`` conversation
      template (``fastchat`` is imported only when this branch is taken).
    - anything else: the tokenizer's own chat template, rendered as text
      (``tokenize=False``) with the generation prompt appended.

    Args:
        input: The user message to wrap.
        model_path: Model name or path; used only to pick the template.
        tokenizer: Used only by the fallback branch; must provide
            ``chat_template`` and ``apply_chat_template``.

    Returns:
        The formatted prompt.

    Raises:
        ValueError: If the fallback branch is reached and
            ``tokenizer.chat_template`` is None.
    """
    lowered_path = model_path.lower()

    if "mpt-7b-8k-instruct" in lowered_path:
        mpt_template = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request."
            "\n\n###Instruction\n{instruction}\n\n### Response\n"
        )
        return mpt_template.format(instruction=input)

    if any(tag in lowered_path for tag in ("longchat", "vicuna")):
        from fastchat.model import get_conversation_template

        conversation = get_conversation_template("vicuna")
        user_role = conversation.roles[0]
        assistant_role = conversation.roles[1]
        conversation.append_message(user_role, input)
        # Presumably a None message leaves the assistant turn open so the
        # prompt ends where the model's reply should begin (FastChat behavior).
        conversation.append_message(assistant_role, None)
        return conversation.get_prompt()

    # Fallback: rely on the tokenizer's built-in chat template.
    if tokenizer.chat_template is None:
        raise ValueError("not support chat_template")
    messages = [{"role": "user", "content": input}]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def main():
    """Measure the answer-token loss of Llama-2-7b-chat on the KV60 validation
    set while hidden-state scaling is active.

    Each record of ``KV60_valid_set.json`` (read as JSON Lines) supplies an
    ``input`` and an ``output``. The ``input`` is wrapped with
    ``format_chat_prompt`` and the ``output`` is appended as the answer. The
    prompt positions are labelled -100 (the conventional ignore index), so the
    loss is intended to cover the answer tokens only. Per-example losses are
    averaged with equal weight per example and printed.

    Examples whose answer tokenizes to no tokens are skipped. With every label
    masked, their loss would presumably be NaN and would poison the average.

    Raises:
        ValueError: If no example with a non-empty answer was evaluated, so
            no average can be computed.
    """
    model_path = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
    ).eval()

    # Configuration of the hidden-state scaling read by modeling_llama_scale.
    model.config.hidden_scale_config = {
        "layers": range(10, 26),  # the layers to apply the scaling
        "dims": [2393],  # the dimensions to apply the scaling
        "factor": -1,  # the scaling factor
        "skip_first": 0,  # skip the first n tokens when scaling hidden states
        "last_recompute_tokens": 1,  # number of tokens whose attention weights are recomputed; reset per example below
        "change_value": False,  # whether to change the value states. If False, only the query and key states are modified
    }

    dataset = pd.read_json("KV60_valid_set.json", lines=True)
    device = model.device
    losses = []
    rows = zip(dataset["input"], dataset["output"])
    for idx, (raw_prompt, answer) in enumerate(tqdm(rows, total=len(dataset))):
        prompt = format_chat_prompt(raw_prompt, model_path, tokenizer)
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
        answer_ids = tokenizer.encode(answer, add_special_tokens=False)
        if not answer_ids:
            # Every label would be -100; skip instead of contributing NaN.
            tqdm.write(f"skipping example {idx}: answer has no tokens")
            continue

        # Shape (1, seq_len): a batch of one prompt+answer sequence.
        input_ids = torch.tensor([prompt_ids + answer_ids], device=device)
        # Mask the prompt so only the answer tokens contribute to the loss.
        labels = torch.tensor([[-100] * len(prompt_ids) + answer_ids], device=device)
        attention_mask = torch.ones_like(input_ids)

        # Recompute attention for all answer tokens plus one more. Presumably
        # the extra one is the last prompt token, whose logits predict the
        # first answer token.
        model.config.hidden_scale_config["last_recompute_tokens"] = len(answer_ids) + 1

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
        losses.append(outputs.loss.item())

    if not losses:
        raise ValueError("no examples with a non-empty answer were evaluated")
    average_loss = sum(losses) / len(losses)
    print("average_loss:", round(average_loss, 4))


if __name__ == '__main__':
    main()