import torch
from transformers import AutoTokenizer, LlamaForCausalLM

dimension = 8192                            # Model (hidden) dimension, split across the attention heads
attention_heads = 64                        # Number of attention heads
head_size = dimension // attention_heads    # Dimensions per attention head
max_length = 4096                           # Longest tokenized prompt that get_att / get_ans will run
model_name = "LLaMA2-70B"                   # Human-readable model label
model_path = "meta-llama/Llama-2-70b-hf"    # Hugging Face Hub checkpoint id
# NOTE(review): not used in this part of the file; the meaning of 300 as a
# temperature is unclear — confirm with the code that reads it.
temperature = 300
tempreture = temperature                    # Misspelled alias kept for existing callers

# Prompts are moved here before the forward pass. With device_map="auto" the
# embedding layer presumably lands on cuda:0 — confirm if GPUs are remapped.
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the weights directly in fp16. The previous `.half()` after loading
# first materialised the whole 70B model in fp32 (twice the memory). It also
# let device_map="auto" plan placement for fp32 sizes.
model = LlamaForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()  # inference only: disables dropout

@torch.no_grad()
def get_att(prompt: str) -> list:
    """Return the final hidden state of the prompt's last token.

    Despite the name, this is not an attention map. It runs the base
    transformer ``model.model`` rather than the causal-LM wrapper, so no
    vocabulary logits are computed. It returns ``last_hidden_state[0, -1]`` as
    a plain Python list. That is presumably ``dimension`` (8192) floats for
    LLaMA-2-70B — confirm against the model config.

    Args:
        prompt: Text to encode.

    Returns:
        The hidden-state vector as a list of floats, or ``[]`` if the prompt
        tokenizes to more than ``max_length`` tokens.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Check the length on the host first, so over-long prompts are never
    # copied to the GPU.
    if input_ids.shape[-1] > max_length:
        return []

    # With return_dict=False the first tuple element is the last hidden state,
    # indexed as [batch, position]; take batch 0, last position.
    hidden = model.model(input_ids=input_ids.to(device), return_dict=False)[0]
    return hidden[0, -1].cpu().tolist()

# Maps the token id the model emits as its answer to the option index 0-3.
# NOTE(review): these ids presumably are the LLaMA-2 tokens for "A", "B", "C",
# "D" — verify with tokenizer.convert_ids_to_tokens.
choices = dict(zip((29909, 29933, 29907, 29928), range(4)))

@torch.no_grad()
def get_length(prompt: str) -> int:
    """Return how many tokens ``tokenizer`` produces for *prompt*.

    The tokenizer is called with its default settings, the same way
    get_att / get_ans call it. The count is therefore directly comparable
    with ``max_length``.
    """
    return tokenizer(prompt, return_tensors="pt")["input_ids"].size(-1)

@torch.no_grad()
def get_ans(prompt: str, full_generate=True):
    """Ask the model for a multiple-choice answer and map it to an option index.

    The prompt is tokenized. If it fits in ``max_length`` tokens, the model's
    next token after the prompt is obtained in one of two ways:

    * ``full_generate=True``: via ``model.generate``, which follows the model's
      generation config. Only the first newly generated token is used.
    * ``full_generate=False``: greedily, as the argmax of the last-position
      logits from a single forward pass.

    That token id is then looked up in ``choices``.

    Args:
        prompt: Full prompt text.
        full_generate: Selects the decoding path described above.

    Returns:
        An ``(answer, human)`` tuple:

        * ``(0, -1)`` if the prompt is longer than ``max_length`` tokens
          (the model is not run);
        * ``(choices[token], 0)`` if the token is one of the option tokens;
        * ``(-1, 1)`` otherwise, presumably flagging the item for manual
          review.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    base_length = inputs.input_ids.shape[-1]

    if base_length > max_length:
        # Sentinel kept unchanged: callers may test for human == -1.
        return 0, -1

    input_ids = inputs.input_ids.to(device)

    if full_generate:
        # Only the first generated token is inspected, so stop after one.
        # The old max_length=4096 kept decoding up to the full context, which
        # wasted up to thousands of forward passes. For a prompt of exactly
        # max_length tokens it produced no new token at all, and indexing the
        # empty result raised IndexError.
        new_tokens = model.generate(input_ids, max_new_tokens=1)[0][base_length:]
        option = new_tokens[0].item() if new_tokens.numel() > 0 else None
    else:
        logits = model(input_ids, return_dict=False)[0]
        # Take the argmax on-device; only the scalar id is copied back.
        option = logits[0, -1].argmax().item()

    answer = choices.get(option)
    if answer is None:
        return -1, 1
    return answer, 0

