#import torch
#import os
from tqdm import tqdm

#from transformers import AutoModelForCausalLM, AutoTokenizer


def get_response_list_from_llama_3(messages_list,
                                   model_id="Meta-Llama-3-8B-Instruct",
                                   tokenizer=None,
                                   model=None,
                                   temperature=1e-8,  # transformers rejects temperature=0 ("has to be a strictly positive float")
                                   max_tokens=512):
    """Generate one chat completion per conversation using a preloaded model.

    The caller must load the model and tokenizer and pass them in. An earlier
    version loaded them inside this function, which produced
    "Some parameters are on the meta device ... offloaded to the cpu" warnings
    (suspected, not confirmed, to be memory not being released fast enough).

    Args:
        messages_list: Sequence of conversations. Each element is passed
            unchanged to ``tokenizer.apply_chat_template`` (presumably a list
            of ``{"role": ..., "content": ...}`` dicts -- confirm with callers).
        model_id: Unused; kept so existing callers that pass it keep working.
        tokenizer: Hugging Face tokenizer with a chat template. Required.
        model: Causal LM exposing ``generate``, ``device`` and
            ``generation_config``. Required.
        temperature: Sampling temperature forwarded to ``model.generate``.
            The tiny default approximates greedy decoding because 0 is rejected.
        max_tokens: Maximum number of new tokens generated per conversation.

    Returns:
        A list of decoded responses (prompt tokens removed, special tokens
        skipped), in the same order as ``messages_list``.

    Raises:
        ValueError: If ``tokenizer`` or ``model`` is not provided.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if tokenizer is None or model is None:
        raise ValueError("tokenizer and model must both be provided")

    # Nothing to generate: return before touching the model's generation_config.
    if not messages_list:
        return []

    # Stop on either eos_token_id or <|eot_id|>. Llama 3 Instruct ends assistant
    # turns with <|eot_id|>. Ids that did not resolve (None, or the unk id for a
    # tokenizer that lacks the token) are dropped so they are never used as stop ids.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    terminators = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
    if eot_id is not None and eot_id != unk_id:
        terminators.append(eot_id)

    # Also record the temperature on the model's generation_config. This mutates
    # the caller's model and persists after the call. It is loop-invariant, so set it once.
    model.generation_config.temperature = temperature

    # Only show a progress bar when there is more than one conversation.
    conversations = tqdm(messages_list) if len(messages_list) > 1 else messages_list

    response_list = []
    for messages in conversations:
        # Ref: https://huggingface.co/docs/transformers/chat_templating#how-do-i-use-chat-templates
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        outputs = model.generate(
            input_ids,
            temperature=temperature,
            top_p=1,
            max_new_tokens=max_tokens,
            eos_token_id=terminators or None,  # None -> fall back to generation_config
            # Explicit pad id silences the "Setting pad_token_id to eos_token_id" warning.
            # https://stackoverflow.com/questions/69609401/suppress-huggingface-logging-warning-setting-pad-token-id-to-eos-token-id
            pad_token_id=tokenizer.eos_token_id,
        )
        # generate() returns prompt + completion; slice off the prompt columns so
        # only the newly generated text is decoded.
        # Ref: https://github.com/huggingface/transformers/issues/17117
        new_tokens = outputs[:, input_ids.shape[1]:]
        response_list.append(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])
    return response_list
