import os
import sys
import re
import jsonlines
import json
import time

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from argparse import ArgumentParser
from utils.model_loader import VLLM_Model
from reader.vllm_reader import vllm_reader_batch
from utils.set_random_seed import set_random_seed
from utils.json_reader import jsonl_loader, json_loader

def llama_batch_inference(
    model_name,
    output_path,
    *,
    input_path="evaluations/google-research/instruction_following_eval/data/input_data.jsonl",
    config_path="configs/config.yml",
    batch_size=200,
):
    """Generate one response per evaluation prompt and save the pairs as JSONL.

    Each record read from ``input_path`` must provide a ``"prompt"`` field.
    All prompts are sent to ``vllm_reader_batch`` together with an empty
    system prompt, and every (prompt, response) pair is written to
    ``output_path`` as ``{"prompt": ..., "response": ...}``. The elapsed
    wall-clock time is printed at the end.

    Args:
        model_name: Model identifier passed to ``VLLM_Model``.
        output_path: Destination JSONL file; it is opened in write mode.
        input_path: JSONL file holding the evaluation prompts.
        config_path: YAML configuration file passed to ``VLLM_Model``.
        batch_size: Batch size forwarded to ``vllm_reader_batch``.

    Raises:
        ValueError: If ``output_path`` is empty, or if fewer responses than
            prompts come back from ``vllm_reader_batch``.
    """
    # Fail before loading the model: an empty path would otherwise only be
    # rejected when the file is opened, after the whole generation run.
    if not output_path:
        raise ValueError("output_path must be a non-empty path to the result file")

    start = time.time()

    # load the model, its tokenizer and the generation parameters
    language_model = VLLM_Model(model_name, config_path)
    model = language_model.load_model()
    tokenizer = language_model.load_tokenizer()
    params = language_model.load_config()

    system_prompt = ""

    # read the content of the evaluation dataset
    # NOTE(review): assumes jsonl_loader returns a sequence of dicts - confirm.
    json_list = jsonl_loader(input_path)

    # first stage generation: one query and one (empty) system prompt per record
    query_list = [record["prompt"] for record in json_list]
    system_prompt_list = [system_prompt] * len(query_list)

    output_list = vllm_reader_batch(
        model,
        tokenizer,
        params,
        query_list,
        system_prompt_list,
        batch_size=batch_size,
    )

    # Check before opening the file so a short result never leaves a
    # partially written output behind.
    if len(output_list) < len(query_list):
        raise ValueError(
            f"expected {len(query_list)} responses but received {len(output_list)}"
        )

    # The context manager closes the writer; no explicit close() is needed.
    with jsonlines.open(output_path, mode="w") as w:
        for query, output in zip(query_list, output_list):
            w.write({"prompt": query, "response": output})

    end = time.time()
    print(f"Time taken: {end-start}")

if __name__ == '__main__':
    # Command-line entry point: read the model name and the result path,
    # then run the batch inference with them.
    cli = ArgumentParser()
    cli.add_argument(
        "--model_name",
        type=str,
        default="llama_model",
        help="Please enter the type of model to evaluate (llama_model or mistral_model)",
    )
    cli.add_argument(
        "--output_path",
        type=str,
        default="",
        help="Please enter the path to save the generated result",
    )

    options = cli.parse_args()
    llama_batch_inference(options.model_name, options.output_path)