import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Location of a local copy of the checkpoint, resolved relative to the
# directory the script is launched from.
model_name = "../../../huggingface/palmyra-med-20b"

# use_fast=False asks for the slow (Python) tokenizer rather than the
# Rust-backed "fast" one.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Weights are loaded in float16; device_map="auto" lets transformers decide
# on which available device(s) each part of the model is placed.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

prompt = "Can you explain in simple terms how vaccines help our body fight diseases?"

# Chat-style prompt template; the "{prompt}" placeholder is filled in with
# str.format below, and the trailing "ASSISTANT:" cues the model to answer.
input_text = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "USER: {prompt} "
    "ASSISTANT:"
)

# Tokenize the filled-in template and move the tensors to the device that
# holds the model's (first) parameters. The model was loaded with
# device_map="auto", so its placement is decided at load time; hard-coding
# "cuda" would crash on CPU-only machines and could disagree with where the
# input embeddings were actually put.
model_inputs = tokenizer(input_text.format(prompt=prompt), return_tensors="pt").to(
    model.device
)

# Keyword arguments forwarded to model.generate(): sample (rather than
# decode greedily) at temperature 0.7, producing at most 512 new tokens.
# repetition_penalty=1.0 is the neutral setting (per the transformers docs,
# 1.0 means no penalty is applied).
gen_conf = dict(
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_tokens=512,
    do_sample=True,
)

# The generated sequence starts with the prompt ids, so remember how many
# prompt tokens there are in order to slice them off afterwards.
prompt_len = model_inputs["input_ids"].shape[-1]
out_tokens = model.generate(**model_inputs, **gen_conf)

# Take the single (batch index 0) sequence, drop the echoed prompt, and turn
# the remaining ids back into text without special tokens.
response_ids = out_tokens[0, prompt_len:]
output = tokenizer.decode(response_ids, skip_special_tokens=True)

print(output)
## Example output from one run (sampling is enabled via do_sample=True, so the text varies between runs) ##
# Vaccines stimulate the production of antibodies by the body's immune system.
# Antibodies are proteins produced by B lymphocytes in response to foreign substances,such as viruses and bacteria.
# The antibodies produced by the immune system can bind to and neutralize the pathogens, preventing them from invading and damaging the host cells.
# Vaccines work by introducing antigens, which are components of the pathogen, into the body.
# The immune system then produces antibodies against the antigens, which can recognize and neutralize the pathogen if it enters the body in the future.
# The use of vaccines has led to a significant reduction in the incidence and severity of many diseases, including measles, mumps, rubella, and polio.
