import os
import sys
import re
import jsonlines
import json
import time

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from module_01_preprocess.prompt_template import question_rewrite_system_prompt, question_rewrite_user_prompt
from utils.model_loader import VLLM_Model
from reader.vllm_reader import vllm_reader_batch
from utils.set_random_seed import set_random_seed
from utils.json_reader import jsonl_loader
from argparse import ArgumentParser

def extract_phrases(text):
    """Return the numbered-list items found in *text*, in order of appearance.

    An item is the run of non-newline characters that follows one or more
    digits, a literal period and one whitespace character, so ``"1. foo"``
    yields ``"foo"``. The match is not anchored to the start of a line.
    """
    numbered_item = re.compile(r'\d+\.\s([^\n]+)')
    return [match.group(1) for match in numbered_item.finditer(text)]

def _build_rewrite_prompts(records):
    """Return parallel (user prompts, system prompts) lists, one entry per record."""
    query_list = [
        question_rewrite_user_prompt.format(question=record["question"], answer=record["answer"])
        for record in records
    ]
    # Every query shares the same system prompt.
    system_prompt_list = [question_rewrite_system_prompt] * len(query_list)
    return query_list, system_prompt_list


def _collect_rewrites(records, outputs):
    """Pair each input record with its model output and build the result dicts."""
    return [
        {
            "question": record["question"],
            # Remove tags such as [1], [2] from the answer.
            "answer": re.sub(r'\[\d+\]', '', record["answer"]),
            "new_question": extract_phrases(output),
        }
        for record, output in zip(records, outputs)
    ]


def llama_batch_inference(
    jsonl_input,
    jsonl_output,
    *,
    path_to_yml="configs/config.yml",
    model_name="qwen_32b_model",
    batch_size=150,
    seed=42,
):
    """Rewrite every question in a dataset with the configured vLLM model.

    Each record loaded from *jsonl_input* has its question and answer
    formatted into ``question_rewrite_user_prompt``. The prompts are sent to
    the model in batches, together with ``question_rewrite_system_prompt``.
    The numbered items in each response are parsed with ``extract_phrases``
    and stored under ``"new_question"``.

    Args:
        jsonl_input: Path of the input dataset, read with ``jsonl_loader``.
            Every record must have ``"question"`` and ``"answer"`` keys.
        jsonl_output: Path the results are written to. Despite the name, the
            file holds a single pretty-printed JSON array (``json.dump`` with
            ``indent=2``), not JSON Lines.
        path_to_yml: Config file passed to ``VLLM_Model``.
        model_name: Model key passed to ``VLLM_Model``.
        batch_size: Batch size forwarded to ``vllm_reader_batch``.
        seed: Value passed to ``set_random_seed`` before the model is loaded.

    Raises:
        RuntimeError: If the reader returns a different number of outputs
            than there are input records. Without this check the results
            would be silently misaligned or truncated.
    """
    start = time.perf_counter()
    set_random_seed(seed)

    # Load the model, tokenizer and generation parameters.
    language_model = VLLM_Model(model_name, path_to_yml)
    model = language_model.load_model()
    tokenizer = language_model.load_tokenizer()
    params = language_model.load_config()

    records = jsonl_loader(jsonl_input)

    query_list, system_prompt_list = _build_rewrite_prompts(records)
    output_list = vllm_reader_batch(
        model, tokenizer, params, query_list, system_prompt_list, batch_size=batch_size
    )
    if len(output_list) != len(records):
        raise RuntimeError(
            f"Expected {len(records)} model outputs, got {len(output_list)}"
        )

    results = _collect_rewrites(records, output_list)

    with open(jsonl_output, "w", encoding="utf-8") as w:
        json.dump(results, w, indent=2, ensure_ascii=False)

    end = time.perf_counter()
    print(f"Time taken: {end-start}")

if __name__ == '__main__':
    # Command-line entry point: read the input/output paths and run the rewrite.
    cli = ArgumentParser()
    cli.add_argument(
        "--jsonl_input",
        type=str,
        default="dataset/natural_question/natural_question.jsonl",
        help="Input JSONL file path",
    )
    cli.add_argument(
        "--jsonl_output",
        type=str,
        default="dataset/natural_question/natural_question_01.jsonl",
        help="Output JSONL file path",
    )
    cli_args = cli.parse_args()
    llama_batch_inference(cli_args.jsonl_input, cli_args.jsonl_output)


