import argparse
import json
import os
from transformers import AutoTokenizer
from tqdm import tqdm
from vllm import LLM, SamplingParams
from datasets import load_dataset
from prompt_templates import instruction_template, knowledge_template, npc_template, math_template,instruction_ablation_template
import ray
import torch
import re
from concurrent.futures import ThreadPoolExecutor


def request_input_format(user_prompt, tokenizer, gemma=False):
    """Render a single-turn conversation into a prompt string.

    Args:
        user_prompt: Text placed in the user turn.
        tokenizer: Object exposing ``apply_chat_template``.
        gemma: When truthy, the system turn is left out and only the user
            turn is sent. NOTE(review): presumably because Gemma chat
            templates do not accept a system role -- confirm.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` when called
        with ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    messages = [{"role": "user", "content": user_prompt}]
    if not gemma:
        # Default layout: a generic system turn precedes the user turn.
        messages.insert(0, {"role": "system", "content": "You are a helpful assistant."})

    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def parse_arguments(argv=None):
    """Parse the command-line options of the synthesis script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``sample_size``, ``model_path``,
        ``output_path`` and ``template`` attributes.
    """
    parser = argparse.ArgumentParser(description="Synthesize text using a specified model and template.")
    parser.add_argument(
        '--sample_size',
        type=int,
        default=0,
        help='Number of samples to process from the dataset; Set it to 0 if you want to use the full set of 200k personas.',
    )
    parser.add_argument('--model_path', type=str, required=True, help='Path to the model.')
    parser.add_argument('--output_path', type=str, required=True, help='Path to the output file.')
    parser.add_argument(
        '--template',
        type=str,
        required=True,
        choices=['instruction', 'knowledge', 'npc', 'math', 'instruction_ablation'],
        # Keep this list in sync with ``choices`` above.
        help="Prompt templates. Choose from 'instruction', 'knowledge', 'npc', 'math' or 'instruction_ablation'.",
    )
    return parser.parse_args(argv)

@ray.remote(num_gpus=1)
class ModelWorker:
    """Ray actor (requesting one GPU) that owns a vLLM engine and its sampling settings."""

    def __init__(self, model_path):
        """Load the tokenizer and the vLLM engine for ``model_path``.

        Args:
            model_path: Model identifier or path passed to both
                ``AutoTokenizer.from_pretrained`` and ``LLM``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.llm = LLM(model=model_path, tensor_parallel_size=1, enforce_eager=True)

        # "<|eot_id|>" is not part of every vocabulary. For a tokenizer that
        # has neither this token nor an unk token, the lookup yields None,
        # which must not be forwarded as a stop id.
        candidate_stop_ids = (
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        )
        stop_token_ids = [token_id for token_id in candidate_stop_ids if token_id is not None]

        self.sampling_params = SamplingParams(
            temperature=0.6,
            top_p=0.95,
            max_tokens=2048,
            stop_token_ids=stop_token_ids,
        )

    def generate(self, prompts):
        """Run generation for ``prompts`` with the stored sampling parameters.

        Args:
            prompts: Prompts forwarded unchanged to ``LLM.generate``.

        Returns:
            The result of ``self.llm.generate``.
        """
        return self.llm.generate(prompts, self.sampling_params)

def process_output(output, persona, prefix_regex):
    """Turn one generation result into a flat record.

    Args:
        output: Generation result exposing ``prompt`` and an ``outputs``
            sequence whose first item has ``text`` and ``finish_reason``.
        persona: Persona string the prompt was built from.
        prefix_regex: Compiled pattern; its matches are removed from the
            generated text.

    Returns:
        A dict with the prompt, the stripped persona, the finish reason, the
        untouched generation, and the generation after prefix removal and
        whitespace stripping.
    """
    completion = output.outputs[0]
    raw_text = completion.text
    reason = completion.finish_reason
    cleaned_text = prefix_regex.sub('', raw_text).strip()

    record = {}
    record['prompt'] = output.prompt
    record['input_persona'] = persona.strip()
    record['finish_reason'] = reason
    record['synthesized_origin'] = raw_text
    record['synthesized_prompt'] = cleaned_text
    return record

def main(args):
    """Synthesize one text per persona and write all records to a JSON file.

    Loads personas from the ``proj-persona/PersonaHub`` dataset, renders each
    into a chat prompt with the selected template, splits the prompts into
    contiguous batches handled by ``ModelWorker`` Ray actors, removes known
    label prefixes from the generations, and dumps the resulting records to
    ``args.output_path``.

    Args:
        args: Namespace with ``sample_size``, ``model_path``, ``output_path``
            and ``template`` attributes.

    Raises:
        ValueError: If there are no personas to process.
        RuntimeError: If ``torch.cuda.device_count()`` reports no device.
    """
    ray.init()
    try:
        # Templates are imported at module level as "<choice>_template".
        template = globals()[f"{args.template}_template"]

        persona_dataset = load_dataset("proj-persona/PersonaHub", data_files="persona.jsonl")['train']
        if args.sample_size > 0:
            persona_dataset = persona_dataset.select(range(args.sample_size))
        personas = persona_dataset['persona']
        print(f"Total number of input personas: {len(personas)}")

        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        tokenizer.pad_token = tokenizer.eos_token

        gemma = "gemma" in args.model_path

        if args.template == "instruction_ablation":
            # The ablation template is used verbatim (no persona substituted),
            # so every entry is the same prompt. Pass ``gemma`` through so it
            # gets the same message layout as the persona-based prompts.
            ablation_prompt = request_input_format(template, tokenizer, gemma=gemma)
            prompts = [ablation_prompt] * len(personas)
        else:
            prompts = [
                request_input_format(template.format(persona=persona.strip()), tokenizer, gemma=gemma)
                for persona in personas
            ]

        if not prompts:
            raise ValueError("No personas to process.")

        print(f"Loaded {len(prompts)} entries to process...\n\n")
        print(f"Sample 0: {prompts[0]}")

        num_gpus = torch.cuda.device_count()
        if num_gpus < 1:
            raise RuntimeError("No CUDA devices found; at least one GPU is required.")

        # Ceil-divide so every prompt lands in exactly one contiguous batch.
        batch_size = (len(prompts) + num_gpus - 1) // num_gpus
        batches = [prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]

        # Ceil division can produce fewer batches than GPUs; only start the
        # workers that will actually receive a batch.
        workers = [ModelWorker.remote(args.model_path) for _ in batches]
        futures = [worker.generate.remote(batch) for worker, batch in zip(workers, batches)]

        chunked_outputs = ray.get(futures)
    finally:
        # Shut Ray down even when loading or generation fails.
        ray.shutdown()

    prefix_regex = re.compile('|'.join([
        r'^\s*User prompt:\s*',  # Allow leading spaces before "User prompt:"
        r'^\s*Title:\s*',        # Allow leading spaces before "Title:"
        r'^\s*Name:\s*',         # Allow leading spaces before "Name:"
        r'^\s*Math problem:\s*', # Allow leading spaces before "Math problem:"
        r'^\s*\*\*Title\*\*:\s*',# Allow leading spaces before "**Title**:"
        r'^\s*\"User prompt:\s*' # Allow leading spaces before "\"User prompt:"
    ]))

    # Flatten per-worker chunks back into one list; batches were contiguous
    # slices submitted in order, so this pairs positionally with ``personas``.
    outputs = [output for chunk in chunked_outputs for output in chunk]

    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        final_dataset = list(tqdm(
            executor.map(process_output, outputs, personas, [prefix_regex] * len(outputs)),
            total=len(outputs),
        ))

    # ``dirname`` is '' for a bare file name, and ``os.makedirs('')`` raises.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # ``ensure_ascii=False`` emits raw non-ASCII characters, so pin the file
    # encoding instead of relying on the platform default.
    with open(args.output_path, 'w', encoding='utf-8') as output_file:
        json.dump(final_dataset, output_file, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Script entry point: read the CLI options, then run the pipeline.
    args = parse_arguments()

    # The attention backend must be chosen through the environment before
    # main() creates any vLLM engine; it is only overridden for model paths
    # containing "gemma".
    if args.model_path.find("gemma") != -1:
        os.environ.update({"VLLM_ATTENTION_BACKEND": "FLASHINFER"})

    main(args)