import os
import sys
import re
import jsonlines
import json
import time

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from module_05_finetune.prompt_template import system_prompt_template_v4
from utils.model_loader import Model, DS_Model,DS2_Model, lora_Model, VLLM_Model
from utils.prompt_template import user_prompt_template
from reader.llama_reader import llama_reader, llama_reader_batch
from reader.ds_reader import ds_reader, ds_reader_batch
from reader.ds2_reader import ds2_reader, ds2_reader_batch
from reader.vllm_reader import vllm_reader, vllm_reader_batch
from utils.set_random_seed import set_random_seed
from utils.json_reader import jsonl_loader, json_loader

def llama_batch_inference(
    model_name="llama_2_7b_chat_model",
    path_to_yml="configs/config.yml",
    dataset_path="module_06_evaluation/dataset/lifestyle_rewrite_question.jsonl",
    output_path="results/domain/llama_base_lifestyle_200.jsonl",
    system_prompt="You are a helpful AI assistant.",
    batch_size=200,
):
    """Run batched generation over an evaluation dataset and save the results.

    Loads the model, tokenizer and generation config through ``VLLM_Model``,
    reads every record of ``dataset_path``, sends each record's
    ``"modified_question"`` (paired with ``system_prompt``) to
    ``vllm_reader_batch``, and writes one JSON line per record to
    ``output_path`` with the keys ``"prompt"``, ``"response"`` (stripped
    model output) and ``"ground_truth"`` (the record's ``"original_answer"``).
    The elapsed time is printed at the end.

    All parameters default to the values that were previously hard-coded, so
    calling the function without arguments behaves as before.

    Args:
        model_name: Model key passed to ``VLLM_Model`` (the script previously
            also mentioned ``"mistral_7b_model"`` as an alternative —
            presumably another key in the config; verify before use).
        path_to_yml: Path to the YAML config passed to ``VLLM_Model``.
        dataset_path: JSONL file whose records provide ``"modified_question"``
            and ``"original_answer"``.
        output_path: Destination JSONL file; its parent directory is created
            if it does not exist.
        system_prompt: System prompt used for every query.
        batch_size: Batch size forwarded to ``vllm_reader_batch``.

    Raises:
        ValueError: If the reader returns fewer outputs than there are
            dataset records.
    """
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    start = time.perf_counter()

    # Load the model, tokenizer and generation parameters.
    language_model = VLLM_Model(model_name, path_to_yml)
    model = language_model.load_model()
    tokenizer = language_model.load_tokenizer()
    params = language_model.load_config()

    # Read the content of the evaluation dataset.
    records = jsonl_loader(dataset_path)

    # First-stage generation: one query and one system prompt per record.
    query_list = [record["modified_question"] for record in records]
    system_prompt_list = [system_prompt] * len(query_list)

    output_list = vllm_reader_batch(
        model,
        tokenizer,
        params,
        query_list,
        system_prompt_list,
        batch_size=batch_size,
    )

    # Fail before touching the output file rather than leaving it truncated.
    if len(output_list) < len(records):
        raise ValueError(
            f"Expected {len(records)} outputs but got {len(output_list)}"
        )

    # Make sure the destination directory exists so the (expensive)
    # generation results are not lost to a FileNotFoundError.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # The context manager closes the writer; no explicit close is needed.
    with jsonlines.open(output_path, mode="w") as writer:
        for record, query, output in zip(records, query_list, output_list):
            writer.write(
                {
                    "prompt": query,
                    "response": output.strip(),
                    "ground_truth": record["original_answer"],
                }
            )

    elapsed = time.perf_counter() - start
    print(f"Time taken: {elapsed}")

# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Seed the random number generators before running inference.
    set_random_seed(42)
    # Run the batched evaluation with its default settings.
    llama_batch_inference()


