import argparse
import json
import re
import os
import random
import time
import numpy as np
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from multiprocessing import Pool, set_start_method

def set_seed(seed=42):
    """Seed every random number generator this script touches.

    Applies *seed*, in order, to Python's ``random`` module, NumPy's global
    generator, torch's CPU generator, and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def format_dataset(record, prompt_template):
    """Render the judge prompt for a single record.

    The template is filled via ``str.format`` with the placeholders
    ``{prompt}``, ``{response_1}`` and ``{response_2}``, taken from the
    record's ``synthesized_*`` fields (missing fields become ``''``).

    Returns a one-element list ``[(record, rendered_prompt)]`` so callers can
    ``extend`` a flat dataset with it.
    """
    substitutions = {
        'prompt': record.get('synthesized_prompt', ''),
        'response_1': record.get('synthesized_response_1', ''),
        'response_2': record.get('synthesized_response_2', ''),
    }
    rendered = prompt_template.format(**substitutions)
    return [(record, rendered)]

# Matches judge output such as "Ranking: 2 > 1" (case-insensitive, spaces optional).
_RANKING_RE = re.compile(r"ranking:\s*([12])\s*>\s*([12])", re.IGNORECASE)


def _default_sampling_params(tokenizer):
    """Build the judge's default SamplingParams, stopping on EOS / end-of-turn."""
    stop_token_ids = []
    if tokenizer.eos_token_id is not None:
        stop_token_ids.append(tokenizer.eos_token_id)
    # NOTE(review): the original looked up convert_tokens_to_ids(""), which appears
    # to be "<|eot_id|>" (Llama-3 end-of-turn) with the tag stripped. A token that
    # is not in the vocab resolves to None or unk_token_id, and neither may be
    # passed to vLLM as a stop id, so only a real, distinct id is added.
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in stop_token_ids:
        stop_token_ids.append(eot_id)
    return SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,
        stop_token_ids=stop_token_ids,
    )


def _parse_ranking(answer):
    """Return the first "a > b" ranking in *answer* as a tuple of '1'/'2', or None."""
    match = _RANKING_RE.search(answer)
    return match.groups() if match else None


def process_batch(prompts, model_id, tokenizer, tensor_parallel_size=4, llm=None,
                  sampling_params=None):
    """Ask the judge model to rank each response pair and annotate the records.

    Args:
        prompts: list of ``(record, judge_prompt)`` pairs, as produced by
            ``format_dataset``.
        model_id: model to load with vLLM when ``llm`` is not supplied.
        tokenizer: Hugging Face tokenizer. It is used for the chat template and,
            when ``sampling_params`` is not supplied, for the stop token ids.
        tensor_parallel_size: number of GPUs to shard the model across when it
            is loaded here.
        llm: optional pre-built ``vllm.LLM`` to reuse instead of loading
            ``model_id`` again.
        sampling_params: optional ``SamplingParams``. The default is
            temperature 0.7, top_p 0.9 and at most 100 new tokens.

    Returns:
        Every input record, in input order. These are the same objects,
        mutated in place. A record whose answer contains a ranking that
        names both responses gains ``'70B_rank_order'`` (e.g. ``"2 > 1"``).
        For any other record a warning is printed and the record is left
        unannotated.
    """
    if not prompts:
        return []  # nothing to judge; skip the expensive model load

    if llm is None:
        llm = LLM(model=model_id, tensor_parallel_size=tensor_parallel_size)
    if sampling_params is None:
        sampling_params = _default_sampling_params(tokenizer)

    # Template each single-turn chat separately. Not every transformers
    # release accepts a batch (list of conversations) in apply_chat_template.
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": judge_prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for _, judge_prompt in prompts
    ]
    responses = llm.generate(chat_prompts, sampling_params)
    answers = [response.outputs[0].text for response in responses]

    results = []
    for (record, _), answer in zip(prompts, answers):
        record_id = record.get('id')
        ranking = _parse_ranking(answer)
        if ranking is None:
            print(f"Warning: No valid ranking found for record {record_id} - {answer}")
        elif len(set(ranking)) != 2:
            # e.g. "1 > 1": both responses must appear exactly once
            print(f"Warning: Invalid ranking (duplicate options) for record {record_id} - {ranking}")
        else:
            record['70B_rank_order'] = " > ".join(ranking)
        results.append(record)

    return results

def _attach_preference_pair(record):
    """Set ``chosen``/``rejected`` on *record* from its ``'70B_rank_order'`` string.

    Returns True on success. If the order is not exactly one of "1 > 2" or
    "2 > 1", prints a warning and returns False, leaving the record unchanged.
    """
    rank_to_key = {'1': 'synthesized_response_1', '2': 'synthesized_response_2'}
    ranks = record['70B_rank_order'].split(" > ")
    if len(ranks) != 2 or set(ranks) != set(rank_to_key):
        print(f"Warning: Invalid rank order for record {record['id']} - {ranks}")
        return False
    # The first-ranked response is preferred, the last-ranked is rejected.
    # A missing field defaults to '' (what format_dataset showed the judge).
    record['chosen'] = record.get(rank_to_key[ranks[0]], '')
    record['rejected'] = record.get(rank_to_key[ranks[-1]], '')
    return True


def main(args):
    """Judge response pairs with an LLM and write chosen/rejected pairs to JSON.

    Reads ``args.input_file_path``. It is expected to be a JSON array of
    objects with ``synthesized_prompt`` / ``synthesized_response_1`` /
    ``synthesized_response_2`` fields. The array is truncated to
    ``args.max_records`` when that value is > 0.

    Each record is assigned a 1-based ``id`` and rendered through the template
    at ``args.prompt_template_path``. The batch is then scored with
    ``args.model_id``.

    Only records with a valid ranking are written, each annotated with
    ``chosen`` and ``rejected``. The output file is
    ``<output_dir>/<model_id with '/'→'_'>-<grandparent dir of input>-<max_records>.json``.
    """
    set_seed(args.seed)

    start = time.perf_counter()

    with open(args.prompt_template_path, 'r', encoding='utf-8') as file:
        prompt_template = file.read()

    with open(args.input_file_path, 'r', encoding='utf-8') as input_file:
        all_records = json.load(input_file)

    # A non-positive --max_records means "use every record".
    if args.max_records > 0:
        all_records = all_records[:args.max_records]

    # 1-based ids let warnings and the dedup below refer to records unambiguously.
    for i, record in enumerate(all_records, start=1):
        record['id'] = i

    formatted_dataset = []
    for record in all_records:
        formatted_dataset.extend(format_dataset(record, prompt_template))

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Processing the batch sequentially to avoid CUDA multiprocessing issues
    results = process_batch(formatted_dataset, args.model_id, tokenizer)

    unique_results = {record['id']: record for record in results}
    final_results = []
    for record in unique_results.values():
        if '70B_rank_order' not in record:
            continue  # process_batch already printed a warning for this record
        if _attach_preference_pair(record):
            final_results.append(record)

    os.makedirs(args.output_dir, exist_ok=True)

    # Name the output after the model and the input file's grandparent directory.
    input_dir_name = os.path.basename(os.path.dirname(os.path.dirname(args.input_file_path)))
    output_file_name = f"{args.model_id.replace('/', '_')}-{input_dir_name}-{args.max_records}.json"
    output_file_path = os.path.join(args.output_dir, output_file_name)

    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        json.dump(final_results, output_file, ensure_ascii=False, indent=2)

    timediff = time.perf_counter() - start
    print(f"Time elapsed: {timediff:.2f} seconds")
    # "Total" counts every judged record; "valid" counts only those written out.
    print(f"Total records: {len(unique_results)}")
    print(f"Records with valid ranking: {len(final_results)}")

if __name__ == "__main__":
    # Use 'spawn' so any child processes (e.g. vLLM tensor-parallel workers)
    # do not inherit a forked CUDA context.
    set_start_method('spawn')
    parser = argparse.ArgumentParser(
        description="Score synthesized response pairs with an LLM judge (vLLM) and emit chosen/rejected pairs"
    )
    parser.add_argument("--input_file_path", type=str, required=True, help="Path to input JSON file")
    parser.add_argument("--output_dir", type=str, default="score", help="Directory to save the output JSON file")
    parser.add_argument("--prompt_template_path", type=str, default="llm_as_a_judge_pairwise_prompt.txt", help="Path to the prompt template file")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for the Python, NumPy and torch RNGs")
    parser.add_argument("--model_id", type=str, required=True, help="ID of the model to use")
    parser.add_argument("--max_records", type=int, default=10000, help="Maximum number of records to evaluate (<= 0 evaluates all)")
    args = parser.parse_args()
    main(args)
