import argparse
import json
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import ray
import torch

def parse_arguments(argv=None):
    """Parse the command-line options of the response-generation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic use and tests).

    Returns:
        argparse.Namespace with ``model_id``, ``input_file``, ``output_file``
        and ``num_responses``.
    """
    parser = argparse.ArgumentParser(description="Generate multiple responses for prompts")
    parser.add_argument("--model_id", type=str, default="Qwen/Qwen2-72B-Instruct", help="Model ID to use")
    # Both files hold a single JSON array of records (read with json.load and
    # written with json.dump), not JSON Lines.
    parser.add_argument(
        "--input_file",
        type=str,
        default="./test_prompt.json",
        help="Path to input JSON file (a list of prompt records)",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="./test_response.json",
        help="Path to output JSON file (the records with responses added)",
    )
    parser.add_argument(
        "--num_responses",
        type=int,
        default=2,
        help="Number of responses to generate for each prompt",
    )
    return parser.parse_args(argv)

def _clean_answer(completion):
    """Turn one vLLM completion into the answer text that gets stored.

    Removes a literal "assistant" header followed by a blank line if the model
    echoed one, and strips surrounding whitespace. Only when generation was
    cut off by the ``max_tokens`` limit (``finish_reason == "length"``) is the
    text trimmed back to its last "." so the stored answer does not end
    mid-sentence. Completed answers are kept whole, so text ending in "?" or
    "!", or containing values such as "3.5", is neither truncated nor given a
    stray trailing ".".
    """
    text = completion.text.replace("assistant\n\n", "").strip()
    if completion.finish_reason == "length" and "." in text:
        text = text.rsplit(".", 1)[0] + "."
    return text


# NOTE: intended to run as a Ray task via ``@ray.remote(num_gpus=1)``; that
# decorator is disabled, so this function runs in the calling process.
def generate_responses(prompts, personas, model_id, num_responses, *,
                       tensor_parallel_size=4, temperature=0.6, top_p=0.95,
                       max_tokens=1024):
    """Generate ``num_responses`` answers for each prompt with a vLLM engine.

    Odd rounds (1, 3, ...) answer the plain prompt. Even rounds (2, 4, ...)
    re-prompt with the persona, a step-by-step instruction and the previous
    round's answer framed as unsatisfying, asking the model to improve it.

    Args:
        prompts: User prompt strings.
        personas: Persona strings aligned with ``prompts``; only used in even
            rounds.
        model_id: Model id/path used for both the tokenizer and vLLM.
        num_responses: Number of rounds, i.e. answers per prompt.
        tensor_parallel_size: Number of GPUs vLLM shards the model across.
        temperature: Sampling temperature passed to ``SamplingParams``.
        top_p: Nucleus-sampling threshold passed to ``SamplingParams``.
        max_tokens: Per-answer generation limit passed to ``SamplingParams``.

    Returns:
        A list with one tuple per prompt containing its answers in round
        order; an empty list when there are no prompts or ``num_responses``
        is less than 1.
    """
    # Nothing to generate: skip loading the tokenizer and model entirely.
    if not prompts or num_responses < 1:
        return []

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = LLM(model=model_id, tensor_parallel_size=tensor_parallel_size)

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    def to_chat(content):
        # Wrap one user message in the model's chat template as raw text.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": content}],
            tokenize=False,
            add_generation_prompt=True,
        )

    formatted_prompts = [to_chat(prompt) for prompt in prompts]

    # all_responses[r] holds the cleaned answers of round r + 1, one per prompt.
    all_responses = []
    for round_idx in range(1, num_responses + 1):
        if round_idx % 2 == 0:
            # Refinement round: feed back the previous round's answers.
            round_prompts = [
                to_chat(f"You are a {persona} Let's think it step by step! \n\n This is an unsatisfying answer and need to further correct and improve: \n '''{previous_answer}'''\n\n {prompt}  ")
                for prompt, persona, previous_answer in zip(prompts, personas, all_responses[-1])
            ]
            if round_prompts:
                print(f"CoT Example {round_prompts[0]}")
        else:
            round_prompts = formatted_prompts

        outputs = llm.generate(round_prompts, sampling_params)
        all_responses.append([_clean_answer(output.outputs[0]) for output in outputs])

    # Transpose rounds x prompts into one tuple of answers per prompt.
    return list(zip(*all_responses))

def _strip_leading_article(persona):
    """Drop a leading whole-word "A " / "An " article from a persona.

    The refinement prompt already reads "You are a {persona} ...", so the
    persona's own indefinite article is removed. Only a whole-word article is
    stripped, so personas such as "An engineer ..." or "Alice ..." are not
    mangled into "n engineer ..." / "lice ...".
    """
    for article in ("An ", "A "):
        if persona.startswith(article):
            return persona[len(article):].lstrip()
    return persona


def main():
    """Load prompt records, generate responses and write the augmented records.

    Reads ``--input_file`` as a JSON list of records (each needs the
    ``'synthesized_prompt'`` and ``'input persona'`` keys). Adds
    ``synthesized_response_<i>`` fields (1-based) and a 1-based ``id`` to every
    record, then writes the list to ``--output_file`` as indented JSON.
    """
    args = parse_arguments()

    # Initialize Ray before the model is built; presumably vLLM attaches to it
    # for multi-GPU execution - confirm. Always shut it down, even on failure.
    ray.init()
    try:
        # Explicit UTF-8: the output is written with ensure_ascii=False, and the
        # platform default encoding may not be able to represent the text.
        with open(args.input_file, "r", encoding="utf-8") as input_file:
            all_records = json.load(input_file)

        all_prompts = [example["synthesized_prompt"] for example in all_records]
        all_personas = [_strip_leading_article(example["input persona"]) for example in all_records]

        # generate_responses runs in this process (its @ray.remote decorator is
        # disabled) and builds its own tensor-parallel vLLM engine, so call it
        # once for all prompts and let vLLM batch internally. Per-GPU batching
        # here would reload the model for every batch, divide by zero when no
        # CUDA device is visible, and use a zero range step on empty input.
        if all_prompts:
            all_responses = generate_responses(all_prompts, all_personas, args.model_id, args.num_responses)
        else:
            all_responses = []
        # No per-prompt tuples (e.g. num_responses < 1): still number every record.
        if not all_responses:
            all_responses = [()] * len(all_records)

        # Add responses and a 1-based id to each record.
        for idx, (record, responses) in enumerate(zip(all_records, all_responses), 1):
            for i, response in enumerate(responses, 1):
                record[f"synthesized_response_{i}"] = response
            record["id"] = idx

        with open(args.output_file, "w", encoding="utf-8") as output_file:
            json.dump(all_records, output_file, ensure_ascii=False, indent=2)
    finally:
        ray.shutdown()


if __name__ == "__main__":
    main()
