import os
import argparse
import json
import jsonlines
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

def load_data(json_file, encoding='utf-8'):
    """Read and parse an entire JSON document from disk.

    Args:
        json_file: Path to the JSON file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8,
            the encoding required for interchanged JSON (RFC 8259), rather
            than the platform locale encoding ``open`` would otherwise pick.

    Returns:
        The decoded JSON value. ``main`` expects a list of records, each
        with ``"query"`` and ``"file"`` keys.

    Raises:
        FileNotFoundError: If *json_file* does not exist.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(json_file, 'r', encoding=encoding) as f:
        return json.load(f)


def write_output_jsonl(output_file, query, response, file):
    """Append a single generation record as one line of a JSON Lines file.

    Args:
        output_file: Path of the ``.jsonl`` file. It is opened in append mode,
            so existing lines are preserved.
        query: The prompt text that was sent to the model.
        response: Raw model output. Every occurrence of the substring
            ``"Output:"`` is removed before the record is written.
        file: Source identifier carried through from the input record.
    """
    with jsonlines.open(output_file, mode='a') as sink:
        record = {
            "query": query,
            "response": response.replace("Output:", ''),
            "file": file,
        }
        sink.write(record)


def build_output_path(output_jsonl, model_path):
    """Derive the per-model output path from the base output path.

    A trailing ``"jsonl"`` is removed from *output_jsonl*. Then
    ``"_<model name>.jsonl"`` is appended. The model name is the last
    component of *model_path* with ``-`` replaced by ``_``.

    Only the trailing suffix is stripped, so ``"jsonl"`` inside a directory
    name is left intact. The ``.`` before the suffix is kept, so existing
    output names do not change. Example:
    ``("../data/tts/squad1.1.jsonl", "../local_model/Llama-3-8B-Instruct")``
    gives ``"../data/tts/squad1.1._Llama_3_8B_Instruct.jsonl"``.
    """
    suffix = "jsonl"
    base = output_jsonl[:-len(suffix)] if output_jsonl.endswith(suffix) else output_jsonl
    # normpath drops a trailing separator, so "model/" still yields "model".
    model_name = os.path.basename(os.path.normpath(model_path)).replace("-", "_")
    return base + "_" + model_name + ".jsonl"


def main(args):
    """Generate one chat completion per input query with vLLM and save them.

    Args:
        args: Parsed CLI namespace providing ``model_path``, ``input_json``
            and ``output_jsonl``.

    The input JSON must be a list of objects with ``"query"`` and ``"file"``
    keys. One result line per query is appended to a per-model file derived
    from ``args.output_jsonl`` (see ``build_output_path``).

    Raises:
        KeyError: If any input record lacks ``"query"`` or ``"file"``. This
            is raised before the model is loaded.
    """
    input_data = load_data(args.input_json)

    # Extract required fields up front so a malformed record fails fast,
    # instead of after the (expensive) model load and generation.
    queries = [item["query"] for item in input_data]
    source_files = [item["file"] for item in input_data]
    if not queries:
        # Nothing to generate. Batched apply_chat_template cannot handle an
        # empty list, and loading the model would be wasted work.
        return

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024, seed=42)

    # enable_lora is kept from the original configuration. No LoRARequest is
    # passed to generate() below, so no adapter is applied to these requests.
    llm = LLM(model=args.model_path, gpu_memory_utilization=0.80, enable_lora=True)
    tokenizer = llm.get_tokenizer()

    # One single-turn user conversation per query, rendered to prompt strings.
    conversations = [[{"role": "user", "content": query}] for query in queries]
    inputs = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True)

    outputs = llm.generate(prompts=inputs, sampling_params=sampling_params)

    output_jsonl = build_output_path(args.output_jsonl, args.model_path)
    out_dir = os.path.dirname(output_jsonl)
    if out_dir:
        # Append mode creates the file but not missing parent directories.
        os.makedirs(out_dir, exist_ok=True)

    # LLM.generate returns outputs in the same order as the input prompts.
    for query, source_file, output in zip(queries, source_files, outputs):
        write_output_jsonl(output_jsonl, query, output.outputs[0].text, source_file)


if __name__ == "__main__":
    # Command-line interface. Every option has a default, so the script can
    # run with no arguments using the squad1.1 paths below.
    cli = argparse.ArgumentParser(description="LLM generation with VLLM")
    option_table = (
        ('--model_path', 'Path to the model', "../local_model/Llama-3-8B-Instruct"),
        ('--input_json', 'Path to the input JSON file', "../data/infer_rewrite/squad1.1.json"),
        ('--output_jsonl', 'Path to the output JSONL file', "../data/tts/squad1.1.jsonl"),
    )
    for flag, help_text, default_value in option_table:
        cli.add_argument(flag, type=str, help=help_text, default=default_value)
    args = cli.parse_args()

    main(args)