import gc
import os

import torch
import transformers
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Location of the base model and of the fine-tuned LoRA adapter checkpoint.
model_ckpt = "../Llama-2-7b-chat-hf"
adapter_ckpt = "../results/final_checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt, trust_remote_code=True)
# No dedicated pad token is configured here, so EOS is reused for padding.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_ckpt,
    device_map="auto",
    trust_remote_code=True
)

# Load the *trained* adapter weights on top of the base model.
# Building the adapter with get_peft_model(model, LoraConfig.from_pretrained(...))
# only reads the adapter configuration and creates a freshly initialised
# adapter, so the fine-tuned weights would never be applied and the merge
# below would leave the base model effectively unchanged.
# NOTE(review): assumes the adapter weights were saved next to
# adapter_config.json in `adapter_ckpt` -- confirm against the training script.
model = PeftModel.from_pretrained(model, adapter_ckpt)

# Fold the LoRA deltas into the base weights and drop the PEFT wrappers so the
# result is a plain transformers model.
model = model.merge_and_unload()

# Text-generation pipeline wrapping the model and tokenizer prepared above.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)

# Fixed preamble placed in front of every question (instruction-style prompt).
prompt = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n'

# Interactive loop: read a question, generate an answer, print it.
while True:
    try:
        text = input('Q:')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
        print()
        break

    # Nothing to answer; avoid spending a generation on a blank line.
    if not text.strip():
        continue

    text = prompt + '### Instruction:\n' + text + '\n\n### Response:\n'

    sequences = pipeline(
        text,
        eos_token_id=tokenizer.eos_token_id,
        # Budget applies to generated tokens only; max_length would also count
        # the prompt, leaving little or no room for the answer on long questions.
        max_new_tokens=200,
        num_beams=1,
        repetition_penalty=2.0,
        # Let the pipeline return only the continuation instead of stripping
        # the prompt afterwards with str.replace, which depends on the decoded
        # text reproducing the prompt exactly.
        return_full_text=False,
    )
    print('A:', sequences[0]['generated_text'])