import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

dimension = 7168                            # Attention (hidden) dimension of the model
attention_heads = 56                        # Number of attention heads
head_size = dimension // attention_heads    # Number of dimensions per attention head
max_length = 4096                           # Prompts longer than this many tokens are skipped
model_name = "Yi"
tempreture = 1200                           # (sic) not referenced in this module; name kept since importers may use it
model_batch_size = 1
model_path = "01-ai/Yi-34B"                 # Hugging Face checkpoint id for both tokenizer and model

device = torch.device('cuda')               # Input ids are moved here before each forward pass
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Left padding keeps the last real token at position -1 of every row, which the
# `[0, -1]` indexing in get_att / get_ans relies on when prompts are padded.
tokenizer.padding_side = "left"
# tokenizer.add_special_tokens({'pad_token': '<s>'})
# Load the weights directly in fp16. Calling `.half()` after loading would first
# materialize the checkpoint in the default fp32 dtype (twice the memory), which
# with device_map="auto" can force layers onto CPU/disk offload.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()


@torch.no_grad()
def get_att(prompt: str) -> "list[float] | torch.Tensor":
    """Return the final hidden state of the last token of ``prompt``.

    Despite the name, this does not return attention weights. It runs the
    decoder stack (``model.model``, i.e. without the LM head) and takes
    ``last_hidden_state[0, -1]``, the vector at the last position of the
    first sequence, converted to a plain Python list.

    Args:
        prompt: Text to encode. If a list of strings is passed, only the first
            sequence's vector is returned.

    Returns:
        A ``list[float]`` (its length is presumably ``dimension``; TODO
        confirm), or the sentinel ``torch.Tensor([0])`` when the prompt is
        longer than ``max_length`` tokens.
    """
    input_ids = tokenizer(prompt, return_tensors="pt", padding=True)["input_ids"]

    # Over-long prompts are not truncated; the sentinel tensor is returned
    # instead, so callers must distinguish it from a real vector.
    if input_ids.shape[-1] > max_length:
        return torch.Tensor([0])

    # Only a single forward pass is made, so building a key/value cache would
    # just waste GPU memory.
    hidden = model.model(
        input_ids=input_ids.to(device),
        use_cache=False,
    ).last_hidden_state
    return hidden[0, -1].cpu().tolist()

# Token id -> multiple-choice option index (0 = A, 1 = B, ..., 6 = G), looked up
# by get_ans. Three groups of ids, each covering the seven options once; the ids
# appear to be output of the vocabulary scan in the __main__ block (TODO confirm).
choices = dict(
    [(370 + option, option) for option in range(7)]
    + [(647, 0), (650, 2), (690, 1), (721, 5), (723, 3), (756, 6), (764, 4)]
    + [(59603, 0), (59608, 2), (59613, 4), (59614, 3),
       (59616, 1), (59621, 5), (59628, 6)]
)

@torch.no_grad()
def get_length(prompt: str) -> int:
    """Return how many tokens ``prompt`` encodes to.

    For a list of prompts, this is the padded width of the batch.
    """
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    token_ids = encoded["input_ids"]
    return token_ids.shape[-1]

@torch.no_grad()
def get_ans(prompts: str, full_generate=True):
    """Predict a multiple-choice answer from the model's next-token logits.

    Runs one forward pass, takes the highest-scoring next token after the
    prompt and maps it to an option index through ``choices``.

    Args:
        prompts: Prompt text. The tokenizer is called with ``padding=True``,
            but only the first sequence of a batch is scored.
        full_generate: Ignored and kept for interface compatibility. This
            model does not support full generation.

    Returns:
        A tuple ``(option, flag)``:
        - ``(index, 0)`` when the predicted token is a key of ``choices``;
        - ``(-1, 1)`` when the predicted token is not a recognized option;
        - ``(0, -1)`` when the prompt exceeds ``max_length`` tokens.
    """
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]

    if input_ids.shape[-1] > max_length:
        return 0, -1

    # Logits are (batch, seq, vocab). Only the last position of the first
    # sequence is needed, so take the argmax there on the model's device
    # instead of copying the full logits tensor to the CPU first. argmax
    # returns the first maximal index on both devices, so results are unchanged.
    logits = model(input_ids.to(device), use_cache=False)["logits"]
    token_id = logits[0, -1].argmax().item()

    option = choices.get(token_id)
    if option is None:
        return -1, 1
    return option, 0


if __name__ == "__main__":

    # Sample multiple-choice prompts for manual experiments (used only by the
    # commented-out snippet below).
    text = '''What's the right answer? Output A B C or D.
A. The wrong answer.
B. The right answer.
C. The wrong answer.
D. The wrong answer.
Answer:
'''
    text2 = '''What's the best answer you think? Output A B C or D.
A. The wrong answer.
B. The right answer.
C. The wrong answer.
D. The wrong answer.
Answer:
'''
    # input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    # print(input_ids)
    # output = model(input_ids.to(device))["logits"][0].cpu()
    # o_tokens = output.argmax(dim=-1)
    # print(o_tokens.shape)
    # print(tokenizer.decode(o_tokens))

    # Show what each entry of `choices` decodes to, to sanity-check the table.
    for k, v in choices.items():
        print("Key: {} Val: {} Chr: {}".format(k, v, tokenizer.decode(k)))

    # Decoded token text -> option index. Each letter matches with or without a
    # leading space; "A." is an extra form that is checked only for A.
    letter_tokens = {"A.": 0}
    for index, letter in enumerate("ABCDEFG"):
        letter_tokens[letter] = index
        letter_tokens[" " + letter] = index

    # Scan the whole vocabulary and print matches in the same `id: index, `
    # form as the `choices` literal. The bound is len(tokenizer) rather than a
    # hard-coded 60000, so ids at the end of the vocabulary (and any added
    # tokens) are not skipped.
    for i in range(len(tokenizer)):
        option = letter_tokens.get(tokenizer.decode([i]))
        if option is not None:
            print(f"{i}: {option}, ", end='')
    print()

    print(model)
    