import argparse
import json
import re
import os
import random
import time
import numpy as np
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import ray

def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's
    CUDA generators.

    Args:
        seed: Integer seed handed to each generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def format_dataset(record, prompt_template):
    """Render the judge prompt for a single record.

    Args:
        record: Mapping that may hold 'synthesized_prompt',
            'synthesized_response_1' and 'synthesized_response_2';
            any missing key is treated as an empty string.
        prompt_template: ``str.format`` template using the ``{prompt}``,
            ``{response_1}`` and ``{response_2}`` placeholders.

    Returns:
        A one-element list holding ``(record, formatted_prompt)``; the
        record is the same object that was passed in.
    """
    # Template placeholder name -> key to read from the record.
    field_sources = (
        ("prompt", 'synthesized_prompt'),
        ("response_1", 'synthesized_response_1'),
        ("response_2", 'synthesized_response_2'),
    )
    template_values = {
        placeholder: record.get(key, '') for placeholder, key in field_sources
    }
    rendered = prompt_template.format(**template_values)
    return [(record, rendered)]

@ray.remote(num_gpus=1)
def process_batch(prompts, model_id, tokenizer):
    """Judge one batch of pairwise comparisons with a vLLM engine.

    Declared as a Ray remote function requesting one GPU. Each formatted
    prompt is wrapped as a single user chat turn, one completion is sampled
    per prompt, and the first ``ranking: A > B`` statement (case-insensitive,
    A and B in {1, 2}) found in the completion is stored on the record.

    Args:
        prompts: List of ``(record, formatted_prompt)`` pairs. Each record
            is a dict that must contain ``'id'`` (used in warning messages).
        model_id: Model name or path passed to ``vllm.LLM``.
        tokenizer: Tokenizer exposing ``eos_token_id``,
            ``convert_tokens_to_ids`` and ``apply_chat_template``.

    Returns:
        The records in input order. A record whose completion contains a
        valid ranking gains ``record['rank_order']`` (``"1 > 2"`` or
        ``"2 > 1"``); other records are returned without that key.
    """
    # Nothing to judge: skip the expensive model load entirely.
    if not prompts:
        return []

    # A tokenizer without an "<|eot_id|>" token (and without an unk token)
    # or without an EOS id yields None here; drop such entries so that only
    # real integer ids are passed as stop tokens.
    candidate_stop_ids = (
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    stop_token_ids = [
        token_id for token_id in candidate_stop_ids if token_id is not None
    ]

    llm = LLM(model=model_id, tensor_parallel_size=1, enforce_eager=True)
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,  # Adjust max tokens if needed
        stop_token_ids=stop_token_ids,
    )

    # One single-turn conversation per prompt, rendered with the chat template.
    messages = [[{"role": "user", "content": text}] for _, text in prompts]
    chat_prompts = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    generations = llm.generate(chat_prompts, sampling_params)
    answers = [generation.outputs[0].text for generation in generations]

    ranking_pattern = re.compile(r"ranking:\s*([12])\s*>\s*([12])", re.IGNORECASE)
    results = []
    for (record, _), answer in zip(prompts, answers):
        # Only the first ranking statement in the answer is considered.
        match = ranking_pattern.search(answer)
        if match is None:
            print(f"Warning: No valid ranking found for record {record['id']} - {answer}")
        else:
            ranking = match.groups()
            if len(set(ranking)) == 2:  # Ensure both options are present
                record['rank_order'] = " > ".join(ranking)
            else:
                print(f"Warning: Invalid ranking (duplicate options) for record {record['id']} - {ranking}")
        results.append(record)

    return results

def main(args):
    """Rank response pairs with an LLM judge and write chosen/rejected data.

    Loads the prompt template and the input JSON, formats one judge prompt
    per record, splits the prompts into one batch per visible CUDA device,
    runs ``process_batch`` on each batch through Ray, and writes the records
    that received a ranking (with ``chosen``/``rejected`` added) as JSON.

    Args:
        args: Parsed arguments providing ``seed``, ``prompt_template_path``,
            ``input_file_path``, ``model_id`` and ``output_file``.

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    # Time the run locally so main() does not depend on a module-level
    # `start` that only exists when the file is executed as a script.
    start_time = time.time()

    ray.init()

    set_seed(args.seed)

    with open(args.prompt_template_path, 'r', encoding='utf-8') as file:
        prompt_template = file.read()

    with open(args.input_file_path, 'r', encoding='utf-8') as input_file:
        all_records = json.load(input_file)

    # Assign 1-based ids, used for warnings and for de-duplicating results.
    for record_id, record in enumerate(all_records, start=1):
        record['id'] = record_id

    formatted_dataset = []
    for record in all_records:
        formatted_dataset.extend(format_dataset(record, prompt_template))

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        raise RuntimeError(
            "No CUDA devices are visible; at least one GPU is required for scoring."
        )

    # Ceiling division gives one batch per GPU; clamp to 1 so an empty
    # dataset does not produce a zero step for range().
    batch_size = max(1, (len(formatted_dataset) + num_gpus - 1) // num_gpus)
    batches = [
        formatted_dataset[offset:offset + batch_size]
        for offset in range(0, len(formatted_dataset), batch_size)
    ]

    futures = [process_batch.remote(batch, args.model_id, tokenizer) for batch in batches]
    batch_results = ray.get(futures)

    # Flatten, de-duplicate by id, and keep only records that were ranked.
    unique_results = {
        record['id']: record for batch in batch_results for record in batch
    }
    final_results = [
        record for record in unique_results.values() if 'rank_order' in record
    ]

    rank_to_key = {'1': 'synthesized_response_1', '2': 'synthesized_response_2'}
    for record in final_results:
        ranks = record['rank_order'].split(" > ")
        if all(rank in rank_to_key for rank in ranks):
            # First rank is the preferred response, last rank the other one.
            record['chosen'] = record[rank_to_key[ranks[0]]]
            record['rejected'] = record[rank_to_key[ranks[-1]]]
        else:
            print(f"Warning: Invalid rank order for record {record['id']} - {ranks}")

    # Ensure directory exists; a bare file name has an empty dirname, which
    # os.makedirs would reject.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save results to the specified output file
    with open(args.output_file, 'w', encoding='utf-8') as output_file:
        json.dump(final_results, output_file, ensure_ascii=False, indent=2)

    timediff = time.time() - start_time
    print(f"Time elapsed: {timediff:.2f} seconds")
    print(f"Total records: {len(all_records)}")
    print(f"Records with valid ranking: {len(final_results)}")

if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows.
    cli_options = (
        ("--input_file_path", str, "response jsonl file", "Path to input JSON file"),
        ("--output_file", str, "score.jsonl", "File path to save the output JSON file"),
        ("--prompt_template_path", str, "./code/llm_as_a_judge_pairwise_prompt.txt", "Path to the prompt template file"),
        ("--seed", int, 42, "Random seed for shuffling the dataset"),
        ("--model_id", str, "NousResearch/Meta-Llama-3-8B-Instruct", "ID of the model to use"),
    )

    parser = argparse.ArgumentParser(description="Parallel Scoring Script with Ray and vllm")
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Module-level timestamp taken right after argument parsing.
    start = time.time()

    # NOTE(review): the check inspects the input file path, not --model_id;
    # confirm that this is the intended trigger for the FLASHINFER backend.
    uses_gemma = "gemma" in args.input_file_path
    if uses_gemma:
        os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

    main(args)
