from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

# Local directory holding the trained LoRA adapter checkpoint.
model_path = "llama_3.1_8b_QLoRA_left/checkpoint-35"
# Hugging Face Hub id of the base model the adapter is applied on top of.
base_model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Load weights in 4-bit NF4 with double quantization; compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# NOTE(review): the tokenizer is loaded from the base model. If the adapter
# checkpoint saved its own tokenizer (e.g. a pad token or padding side set
# during training), it may need to be loaded from model_path instead — confirm.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# The id string and the loaded module use separate names so `base_model` is
# never ambiguous; it is bound only to the loaded model object.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

# Load the LoRA adapter on top of the base model. PeftModel.generate works with
# the unmerged adapter, so merge_and_unload() is not required for generation.
# Merging into 4-bit weights would also involve dequantization and may
# introduce rounding error (see PEFT docs) — avoid unless needed.
model = PeftModel.from_pretrained(base_model, model_path)
model.eval()

prompt = "Hello, my name is XXXX-19"
# device_map="auto" decides weight placement, so move inputs to the model's
# device instead of assuming "cuda".
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    # Presumably the Llama tokenizer defines no pad token; passing one
    # explicitly avoids generate()'s fallback warning — verify for this tokenizer.
    pad_token_id=tokenizer.eos_token_id,
)
# For a causal LM, outputs[0] is the prompt tokens followed by the continuation.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))