import gc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def check_model(model_id, save_dir):
    """Load a causal LM and print its completion of a fixed prompt.

    The tokenizer is loaded from ``model_id`` while the model weights are
    loaded from ``save_dir``; the two may differ (the caller in this file
    passes a checkpoint path as ``save_dir``). The model is placed on
    ``cuda:0`` and generation runs with sampling disabled, so a CUDA device
    is required.

    Args:
        model_id: Identifier or path passed to
            ``AutoTokenizer.from_pretrained``.
        save_dir: Identifier or path passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        None. The decoded text is written to stdout.
    """
    torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(save_dir, device_map="cuda:0", torch_dtype="auto")

    # Pre-bind so the cleanup below is valid even if an early step raises.
    encoded = None
    output_ids = None
    try:
        # The tokenizer output is unpacked into generate() as keyword
        # arguments, so it is moved to the model's device as a whole.
        encoded = tokenizer("Q: hi there, what is your name? A:", return_tensors="pt").to(model.device)

        # max_length bounds the whole sequence (prompt plus generated tokens).
        output_ids = model.generate(**encoded, do_sample=False, max_length=100)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    finally:
        # empty_cache() only returns *unoccupied* cached blocks, so the
        # references to the model and tensors must be dropped first or the
        # weights stay resident on the GPU.
        del model, encoded, output_ids
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Both runs use the same tokenizer id; only the weights source changes:
    # first the checkpoint path, then the base model id itself.
    base_model_id = "meta-llama/Meta-Llama-3.1-8B"
    for weights_source in ("output-llama-finetune/checkpoint-200", base_model_id):
        check_model(base_model_id, weights_source)
