# Generate several sampled chat responses per prompt with vLLM, fanning prompts out across GPUs with Ray.

import argparse
import json
import os
from typing import List, Dict
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import ray

def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the response-generation script.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        Namespace with ``model_id``, ``input_file``, ``output_file`` and
        ``num_responses``.

    Exits via ``argparse`` (``SystemExit``) with a usage message when a
    required path is missing or ``--num_responses`` is not positive.
    """
    parser = argparse.ArgumentParser(description="Generate multiple responses for prompts using vLLM")
    parser.add_argument("--model_id", type=str, default="NousResearch/Meta-Llama-3-8B-Instruct", help="Model ID to use")
    # The previous defaults were placeholder strings ("prompt jsonl file"), which
    # could only produce a confusing FileNotFoundError (or write a file with that
    # literal name), so real paths are required instead.
    parser.add_argument("--input_file", type=str, required=True,
                        help="Path to input JSON file: a list of records, each with a 'synthesized_prompt' field")
    parser.add_argument("--output_file", type=str, required=True,
                        help="Path to output JSON file (the input records plus generated responses)")
    parser.add_argument("--num_responses", type=int, default=2, help="Number of responses to generate for each prompt")
    args = parser.parse_args(argv)
    # Asking for zero (or a negative number of) samples per prompt is meaningless;
    # reject it here rather than deep inside the GPU workers.
    if args.num_responses < 1:
        parser.error("--num_responses must be a positive integer")
    return args

# Role header some chat models echo at the very start of a completion.
_ASSISTANT_HEADER = "assistant\n\n"
# Characters treated as the end of a complete sentence when trimming output.
_SENTENCE_TERMINATORS = (".", "!", "?")


def _stop_token_ids(tokenizer) -> List[int]:
    """Return the token IDs that should end generation for ``tokenizer``'s model.

    Includes ``eos_token_id`` when it is set. Adds Llama-3's ``<|eot_id|>``
    end-of-turn token only when the vocabulary really contains it.
    """
    candidates = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    # For vocabularies without this token, convert_tokens_to_ids falls back to
    # the unknown-token ID, which is None when the tokenizer has no unk token.
    # Neither fallback value is a real end-of-turn marker.
    if eot_id != tokenizer.unk_token_id:
        candidates.append(eot_id)
    stop_ids = []
    for token_id in candidates:
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)
    return stop_ids


def _clean_response(text: str) -> str:
    """Normalise one generated completion.

    Processing steps:
    - Drop a *leading* ``"assistant\\n\\n"`` role header, if present.
    - Strip surrounding whitespace.
    - Cut any trailing fragment after the last ``.``, ``!`` or ``?``. That
      fragment is typically an incomplete sentence left when generation hits
      ``max_tokens``.

    Text that already ends with a terminator is kept unchanged. Text with no
    terminator at all gets a ``.`` appended (the historical behaviour). Empty
    text stays empty.
    """
    text = text.lstrip()
    # Only strip the header as a prefix: a global replace would also delete
    # legitimate text such as "your assistant\n\nHow can I help?".
    if text.startswith(_ASSISTANT_HEADER):
        text = text[len(_ASSISTANT_HEADER):]
    text = text.strip()
    if not text or text.endswith(_SENTENCE_TERMINATORS):
        return text
    last_end = max(text.rfind(mark) for mark in _SENTENCE_TERMINATORS)
    if last_end == -1:
        return text + "."
    return text[:last_end + 1]


@ray.remote(num_gpus=1)
def generate_responses(prompts: List[str], model_id: str, num_responses: int) -> List[List[str]]:
    """Sample ``num_responses`` chat completions for every prompt on one GPU.

    Runs as a Ray task that reserves one GPU. It loads the tokenizer and a vLLM
    engine for ``model_id`` and wraps each prompt as a single user turn using
    the model's chat template. All prompts are then generated in one
    ``llm.generate`` call.

    Args:
        prompts: Raw user prompts for this worker's shard.
        model_id: Model ID (or local path) used for both tokenizer and engine.
        num_responses: Samples drawn per prompt (``SamplingParams.n``).

    Returns:
        One list per prompt, in the same order as ``prompts``. Each list holds
        ``num_responses`` strings cleaned by ``_clean_response``.
    """
    # Nothing to do: avoid loading the model just to generate nothing.
    if not prompts:
        return []

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    llm = LLM(model=model_id, enforce_eager=True)

    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.95,
        max_tokens=1024,
        n=num_responses,
        stop_token_ids=_stop_token_ids(tokenizer),
    )

    # add_generation_prompt=True appends the template's assistant-turn opener so
    # the model starts answering instead of continuing the user message.
    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    request_outputs = llm.generate(formatted_prompts, sampling_params)
    return [
        [_clean_response(completion.text) for completion in request_output.outputs]
        for request_output in request_outputs
    ]

def _load_records(path: str) -> List[Dict]:
    """Read prompt records from ``path`` as UTF-8.

    Accepts either a single JSON document holding a list of records (or one
    record object) or JSON Lines with one record per non-blank line.
    """
    with open(path, "r", encoding="utf-8") as input_file:
        content = input_file.read()
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Not a single JSON document -- fall back to JSON Lines.
        data = [json.loads(line) for line in content.splitlines() if line.strip()]
    # A one-record JSON Lines file parses as a bare object; treat it as a list.
    if isinstance(data, dict):
        data = [data]
    return data


def _split_evenly(items: List[str], num_chunks: int) -> List[List[str]]:
    """Split ``items`` into at most ``num_chunks`` contiguous, order-preserving chunks."""
    if not items:
        return []
    chunk_size = (len(items) + num_chunks - 1) // num_chunks  # ceiling division
    return [items[start:start + chunk_size] for start in range(0, len(items), chunk_size)]


def main():
    """Generate responses for every prompt in ``--input_file`` and save them.

    Steps:
    1. Read the records.
    2. Shard their ``synthesized_prompt`` values across ``generate_responses``
       Ray tasks (one per visible GPU, or a single task for 70B/72B models).
    3. Write the records back to ``--output_file`` as a UTF-8 JSON array. Each
       record gains ``synthesized_response_<k>`` fields (k starts at 1) and a
       1-based ``id``.

    Raises:
        RuntimeError: if there are prompts but no CUDA device is visible, or if
            the workers return a different number of results than prompts sent.
    """
    args = parse_arguments()

    env_vars = {}
    if "gemma" in args.model_id:
        # NOTE(review): presumably needed for Gemma support in vLLM's attention
        # backends -- confirm against the vLLM version in use.
        env_vars["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
        os.environ.update(env_vars)

    records = _load_records(args.input_file)
    prompts = [record["synthesized_prompt"] for record in records]

    if "72B" in args.model_id or "70B" in args.model_id:
        # NOTE(review): this runs one task, but generate_responses still reserves
        # only num_gpus=1 and builds LLM without tensor parallelism -- confirm
        # that such models actually fit on a single GPU.
        num_workers = 1
    else:
        num_workers = torch.cuda.device_count()
    if prompts and num_workers < 1:
        raise RuntimeError("No CUDA device is visible; each generate_responses worker needs one GPU.")
    batches = _split_evenly(prompts, num_workers)

    # Pass env vars through runtime_env as well, so workers see them even when
    # they do not inherit this process's environment.
    ray.init(runtime_env={"env_vars": env_vars} if env_vars else None)
    try:
        futures = [generate_responses.remote(batch, args.model_id, args.num_responses) for batch in batches]
        # ray.get on a list returns results in the same order as the futures.
        chunked_responses = ray.get(futures)
    finally:
        ray.shutdown()

    responses_per_prompt = [responses for chunk in chunked_responses for responses in chunk]
    # zip() would silently drop records on a mismatch; fail loudly instead.
    if len(responses_per_prompt) != len(records):
        raise RuntimeError(
            f"Got {len(responses_per_prompt)} response lists for {len(records)} records."
        )

    for record_id, (record, responses) in enumerate(zip(records, responses_per_prompt), start=1):
        for k, response in enumerate(responses, start=1):
            record[f"synthesized_response_{k}"] = response
        record["id"] = record_id

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which a
    # non-UTF-8 locale default encoding could fail to write.
    with open(args.output_file, "w", encoding="utf-8") as output_file:
        json.dump(records, output_file, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()