import argparse
import json
import os
from transformers import AutoTokenizer
from tqdm import tqdm
from vllm import LLM, SamplingParams
from datasets import load_dataset
from prompt_templates import instruction_template, knowledge_template, npc_template, math_template
import ray
import torch

def request_input_format(user_prompt, tokenizer):
    """Render ``user_prompt`` as a chat-formatted prompt string.

    A two-message conversation (a fixed system message followed by the user
    message) is handed to ``tokenizer.apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``.

    Args:
        user_prompt: Text placed in the ``user`` turn.
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` returns for the conversation.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def parse_arguments(argv=None):
    """Parse the command-line options of the synthesis script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            zero-argument callers behave as before.

    Returns:
        argparse.Namespace with the attributes ``sample_size`` (int),
        ``model_path`` (str), ``output_path`` (str) and ``template`` (str,
        one of ``instruction``, ``knowledge``, ``npc``, ``math``).
    """
    parser = argparse.ArgumentParser(description="Synthesize text using a specified model and template.")
    parser.add_argument(
        '--sample_size',
        type=int,
        default=100,
        help='Number of samples to process from the dataset; Set it to 0 if you want to use the full set of 200k personas.',
    )
    parser.add_argument('--model_path', type=str, default="Qwen/Qwen2-72B-Instruct", help='Path to the model.')
    parser.add_argument('--output_path', type=str, default='./test_prompt.json', help='Path to the output file.')
    parser.add_argument(
        '--template',
        type=str,
        default='instruction',
        choices=['instruction', 'knowledge', 'npc', 'math'],
        help=(
            "Prompt templates. Choose from 'instruction', 'knowledge', 'math' or 'npc'. "
            # The templates are imported from the prompt_templates module, so
            # point users at that file.
            "You can also add more customized templates in prompt_templates.py"
        ),
    )
    return parser.parse_args(argv)

# @ray.remote(num_gpus=1)
def generate_responses(prompts, model_path, max_len, tensor_parallel_size=4):
    """Generate one completion per prompt with a vLLM engine.

    Args:
        prompts: Prompt strings passed to ``LLM.generate``.
        model_path: Model identifier/path used for both the tokenizer and
            the vLLM engine.
        max_len: Value passed as ``max_tokens`` to the sampling parameters.
        tensor_parallel_size: Value passed as ``tensor_parallel_size`` to
            the vLLM engine. Defaults to 4 (the previously hard-coded
            value); adjust it to the GPUs you are using.

    Returns:
        The result of ``llm.generate(prompts, sampling_params)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size)

    # A tokenizer without an EOS token reports None; do not forward [None]
    # as a stop token id.
    eos_token_id = tokenizer.eos_token_id
    stop_token_ids = [] if eos_token_id is None else [eos_token_id]

    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.95,
        max_tokens=max_len,
        stop_token_ids=stop_token_ids,
    )
    return llm.generate(prompts, sampling_params)

def _strip_known_prefixes(text):
    """Drop leading label prefixes (e.g. ``Title:``) from a generated text."""
    # Prefixes are checked in this fixed order and each match is stripped
    # before the next check, mirroring the original sequential behavior.
    for prefix in ("User prompt:", "Title:", "Name:", "Math problem:", "**Title:**"):
        if text.startswith(prefix):
            text = text[len(prefix):].strip()
    return text


def main(args):
    """Build persona prompts, run generation and write the results as JSON.

    Args:
        args: Namespace providing ``template``, ``sample_size``,
            ``model_path`` and ``output_path``.

    Raises:
        ValueError: If ``args.template`` is not one of the known templates.
    """
    # Load the appropriate template
    templates = {
        "instruction": instruction_template,
        "knowledge": knowledge_template,
        "npc": npc_template,
        "math": math_template,
    }
    template = templates.get(args.template)
    if template is None:
        raise ValueError("Invalid template type. Choose from 'instruction', 'knowledge', 'math' or 'npc'.")

    # Load the dataset
    persona_dataset = load_dataset("proj-persona/PersonaHub", data_files="persona.jsonl")['train']
    if args.sample_size > 0:
        # Clamp so that asking for more samples than exist uses the full set
        # instead of failing inside select().
        sample_size = min(args.sample_size, len(persona_dataset))
        persona_dataset = persona_dataset.select(range(sample_size))
    print(f"Total number of input personas: {len(persona_dataset)}")

    # Read the persona column once; indexing the dataset column inside a
    # loop would re-read the whole column on every iteration.
    personas = [persona.strip() for persona in persona_dataset['persona']]

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    tokenizer.pad_token = tokenizer.eos_token

    max_len = 2048
    prompts = [
        request_input_format(template.format(persona=persona), tokenizer)
        for persona in personas
    ]

    print(f"Loaded {len(prompts)} entries to process...\n\n")
    if prompts:
        print(f"Sample 0: {prompts[0]}")

    # Calculate the batch size based on the number of available GPUs *(70B--> 1)
    num_gpus = 1
    # max(1, ...) keeps the range step non-zero when there are no prompts.
    batch_size = max(1, (len(prompts) + num_gpus - 1) // num_gpus)
    batches = [prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]

    # Batches are processed one after another; ray.init()/remote dispatch is
    # not used here.
    chunked_outputs = [generate_responses(batch, args.model_path, max_len) for batch in batches]
    outputs = [output for chunk in chunked_outputs for output in chunk]

    final_dataset = []
    # Assumes outputs come back in prompt order, as the original index-based
    # pairing did.
    for persona, output in zip(personas, outputs):
        completion = output.outputs[0]
        final_dataset.append({
            'prompt': output.prompt,
            'input persona': persona,
            'finish_reason': completion.finish_reason,
            'synthesized_origin': completion.text,
            'synthesized_prompt': _strip_known_prefixes(completion.text),
        })

    # Ensure the output directory exists. A bare file name has an empty
    # dirname, and os.makedirs('') raises, so skip it in that case.
    output_file_path = f"{args.output_path}"
    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the results to the output file. ensure_ascii=False emits raw
    # non-ASCII characters, so the encoding is pinned to UTF-8 rather than
    # left to the platform default.
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        json.dump(final_dataset, output_file, ensure_ascii=False, indent=2)

    ray.shutdown()

if __name__ == "__main__":
    # Script entry point: parse the CLI flags and run the pipeline with them.
    main(parse_arguments())
