import openai
import re
import json
import os
import itertools
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import threading
import time


# --- File paths ------------------------------------------------------------
# Input: JSON list of samples. main() reads sample['id'], sample['prompt'] and
# sample['responses'] (each response has 'response_id' and 'response_content').
INPUT_FILE_PATH = "xxx"
# Output / checkpoint: JSON list of processed samples, each with a
# 'compare_results' list added. Re-read on startup so that sample ids already
# present are skipped.
OUTPUT_FILE_PATH = "xxx"
# Plain-text log; one line is appended for every comparison that ultimately failed.
ERROR_LOG_PATH = "xxx"

# --- Judge model endpoint (OpenAI-compatible client) ------------------------
API_BASE_URL = "xxx"
# NOTE(review): avoid committing a real key here; consider reading it from an
# environment variable instead (os is already imported).
API_KEY = "xxx" 
MODEL_NAME = "xxx"
MAX_TOKENS = 8000
# NOTE(review): currently unused -- the temperature argument is commented out
# in the chat.completions.create call in run_single_comparison.
TEMPERATURE = 0.2

# --- Run control -------------------------------------------------------------
# Attempts per comparison; covers both API errors and unparseable replies.
MAX_RETRIES = 5 
# Save the output file after every N newly processed samples.
CHECKPOINT_INTERVAL = 3  
# Thread-pool size used for the judge calls of a single sample.
MAX_WORKERS = 16  


# Judge instructions. parse_llm_verdict looks for the "[[A]]", "[[B]]" and
# "[[C]]" markers that this prompt asks the model to emit.
SYSTEM_PROMPT = """Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Do not provide your explanation, only output your final verdict by strictly following this format: "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, and "[[C]]" for a tie."""

# Per-comparison user message; filled with str.format(question=..., answer_a=..., answer_b=...).
USER_PROMPT_TEMPLATE = """[User Question]
{question}

[The Start of Assistant A's Answer]
{answer_a}
[The End of Assistant A's Answer]

[The Start of Assistant B's Answer]
{answer_b}
[The End of Assistant B's Answer]

Remember only output "[[A]]" or "[[B]]" or "[[C]]" without any explanation. Output "[[A]]" if assistant A is better, "[[B]]" if assistant B is better, and "[[C]]" for a tie."""



# Serializes writes to the output JSON file and the error log, because
# log_error_message is called from the comparison worker threads.
file_lock = threading.Lock()


def load_json_data(filepath):
    """Read and return the JSON document stored at *filepath* (UTF-8).

    A missing file or invalid JSON is not an error: a warning is printed and
    an empty list is returned, so callers can treat it as "no data yet".
    """
    try:
        with open(filepath, encoding='utf-8') as handle:
            payload = json.load(handle)
    except (json.JSONDecodeError, FileNotFoundError):
        print(f"Warning: Could not load or find the file {filepath}. Starting from scratch.")
        payload = []
    return payload


def load_or_initialize_results(filepath):
    """Load previously saved results and collect the sample ids they cover.

    Returns a ``(results, ids)`` tuple: the list loaded from *filepath* via
    load_json_data (empty when the file is missing or not valid JSON) and the
    set of every entry's 'id'.
    """
    results = load_json_data(filepath)
    seen_ids = set()
    for record in results:
        seen_ids.add(record['id'])
    return results, seen_ids


def save_results(data, filepath):
    """Write *data* to *filepath* as indented UTF-8 JSON, replacing it atomically.

    The JSON is written to ``<filepath>.tmp`` first, flushed to disk, and then
    moved over *filepath* with ``os.replace``. If the process is interrupted
    mid-write (or ``json.dump`` raises), the previous file therefore stays
    intact. Without this, the file could be left truncated; a truncated
    checkpoint fails to parse on the next run, load_json_data then returns an
    empty list, and the old results are overwritten.

    The module-level ``file_lock`` serializes this with error-log writes.
    """
    # Same directory as the target, so os.replace stays on one filesystem.
    tmp_path = f"{filepath}.tmp"
    with file_lock:
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=4, ensure_ascii=False)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, filepath)
        except BaseException:
            # Best-effort removal of the partial temp file; the original error propagates.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise


def log_error_message(sample_id, id1, id2, attempt_info, reason):
    """Append a one-line failure record to ERROR_LOG_PATH and echo it to stdout.

    The append happens under ``file_lock`` because worker threads call this
    concurrently.
    """
    entry = (
        f"sample_id: {sample_id}, response_ids: ({id1}, {id2}), "
        f"attempt: {attempt_info}, reason: {reason}"
    )
    with file_lock, open(ERROR_LOG_PATH, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{entry}\n")
    print(f"ERROR logged: {entry}")

def parse_llm_verdict(llm_output, resp_a_id, resp_b_id):
    """Translate the judge's reply into a winner code.

    Returns *resp_a_id* if "[[A]]" is present, *resp_b_id* for "[[B]]",
    -1 for a tie ("[[C]]"), and -2 when no marker is found.

    The markers are tested in the order A, B, C, so when a reply contains
    several of them the first in that order wins.
    NOTE(review): that favours position A on ambiguous replies.
    """
    markers = (
        ("[[A]]", resp_a_id),
        ("[[B]]", resp_b_id),
        ("[[C]]", -1),  # tie
    )
    for token, verdict in markers:
        if token in llm_output:
            return verdict
    return -2  # no recognizable verdict

def _strip_reasoning(raw_content):
    """Return the model reply with any ``<think>...</think>`` reasoning removed.

    Everything up to and including the last ``</think>`` is dropped. This works
    whether or not a newline follows the closing tag, and also when the reply
    contains a closing tag without a matching opening tag. ``None`` content is
    treated as an empty reply, which then fails to parse and is retried.
    """
    text = raw_content or ""
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    return text.strip()


def run_single_comparison(client, sample_id, question, resp_a, resp_b, comparison_num):
    """Ask the judge model which of two responses better answers *question*.

    *resp_a* is presented as "Assistant A" and *resp_b* as "Assistant B".
    The request is retried up to MAX_RETRIES times, both on exceptions raised
    while calling the API or reading its reply and on replies without a
    verdict marker. The final failure is written to the error log.

    Args:
        client: OpenAI-compatible client used for ``chat.completions.create``.
        sample_id: Sample identifier; used only for logging.
        question: User prompt shown to the judge.
        resp_a, resp_b: Response dicts with 'response_id' and 'response_content'.
        comparison_num: Label for this A/B ordering, used in error-log entries.

    Returns:
        dict with 'a_id', 'b_id' and 'winner'. 'winner' is the preferred
        response id, -1 for a tie, or -2 if no usable verdict was obtained.
    """
    resp_a_id = resp_a['response_id']
    resp_b_id = resp_b['response_id']

    def outcome(winner):
        return {"a_id": resp_a_id, "b_id": resp_b_id, "winner": winner}

    user_prompt = USER_PROMPT_TEMPLATE.format(
        question=question,
        answer_a=resp_a['response_content'],
        answer_b=resp_b['response_content'],
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    for attempt in range(MAX_RETRIES):
        is_last_attempt = attempt >= MAX_RETRIES - 1
        try:
            chat_completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=MAX_TOKENS,
                # NOTE(review): TEMPERATURE is intentionally not sent (kept
                # disabled as before); re-enable if the endpoint accepts it.
                # temperature=TEMPERATURE,
            )
            # Kept inside the try so malformed replies (e.g. no choices) are retried too.
            cleaned_content = _strip_reasoning(chat_completion.choices[0].message.content)
        except Exception as e:
            print(f"  - Compare (A:{resp_a_id}, B:{resp_b_id}): API call failed on attempt {attempt + 1}. Error: {e}")
            if is_last_attempt:
                log_error_message(sample_id, resp_a_id, resp_b_id, f"Comparison {comparison_num}", f"API call failed after {MAX_RETRIES} attempts: {e}")
                return outcome(-2)
            time.sleep(2)
            continue

        result = parse_llm_verdict(cleaned_content, resp_a_id, resp_b_id)
        print(f"  - Compare (A:{resp_a_id}, B:{resp_b_id}): Model response='{cleaned_content}', Parsed result={result}")

        if result != -2:
            return outcome(result)
        if is_last_attempt:
            log_error_message(sample_id, resp_a_id, resp_b_id, f"Comparison {comparison_num}", f"Failed to parse model output after {MAX_RETRIES} retries. Final output: '{cleaned_content}'")
            return outcome(-2)
        print(f"    -> Parsing failed. Retrying ({attempt + 1}/{MAX_RETRIES})...")
        time.sleep(1)

    # Only reached when MAX_RETRIES <= 0 (the loop body never runs).
    return outcome(-2)

def _evaluate_sample(client, sample):
    """Judge every pair of responses in *sample*, once in each A/B order.

    Returns the list stored as sample['compare_results']. It has one entry per
    unordered pair, with response ids in ascending order. Each entry's
    "result" is [winner when id1 is shown as A, winner when id2 is shown as A].
    A winner is a response id, -1 for a tie, or -2 if no verdict was obtained
    (including a comparison whose worker raised).
    """
    sorted_responses = sorted(sample['responses'], key=lambda r: r['response_id'])
    response_pairs = list(itertools.combinations(sorted_responses, 2))

    # Each pair is judged twice with the A/B order swapped.
    tasks = []
    for r1, r2 in response_pairs:
        tasks.append((r1, r2, 1))
        tasks.append((r2, r1, 2))

    # (smaller_id, larger_id) -> {id of the response shown as A: winner}
    comparison_outcomes = {}
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [
            executor.submit(
                run_single_comparison,
                client,
                sample['id'],
                sample['prompt'],
                resp_a,
                resp_b,
                comparison_num,
            )
            for resp_a, resp_b, comparison_num in tasks
        ]
        for future in as_completed(futures):
            try:
                result = future.result()
                pair_key = tuple(sorted((result['a_id'], result['b_id'])))
                comparison_outcomes.setdefault(pair_key, {})[result['a_id']] = result['winner']
            except Exception as exc:
                # A lost comparison falls back to -2 below instead of aborting the sample.
                print(f"An unexpected error occurred in a thread: {exc}")

    compare_results = []
    for r1, r2 in response_pairs:
        id1 = r1['response_id']
        id2 = r2['response_id']
        outcomes = comparison_outcomes.get((id1, id2), {})
        compare_results.append({
            "response_id_1": id1,
            "response_id_2": id2,
            "result": [outcomes.get(id1, -2), outcomes.get(id2, -2)],
        })
    return compare_results


def main():
    """Judge every not-yet-processed sample and write the results to OUTPUT_FILE_PATH.

    Resumes from an existing output file: samples whose 'id' is already there
    are skipped. Results are checkpointed every CHECKPOINT_INTERVAL newly
    processed samples. They are always saved again on exit, including when
    an exception or Ctrl-C stops the loop, so completed samples are not lost.
    """
    client = openai.OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    print(f"Client initialized for model at {API_BASE_URL}")
    all_samples = load_json_data(INPUT_FILE_PATH)
    if not all_samples:
        print("Error: Input file is empty or could not be loaded. Exiting.")
        return

    results_data, processed_ids = load_or_initialize_results(OUTPUT_FILE_PATH)
    print(f"Found {len(all_samples)} total samples. {len(processed_ids)} already processed.")

    newly_processed = 0
    try:
        for sample in tqdm(all_samples, desc="Processing Samples"):
            if sample['id'] in processed_ids:
                continue

            print(f"\nProcessing sample ID: {sample['id']}")
            sample['compare_results'] = _evaluate_sample(client, sample)
            results_data.append(sample)
            # Also skip a repeated id later in this run, matching resume behaviour.
            processed_ids.add(sample['id'])

            newly_processed += 1
            if newly_processed % CHECKPOINT_INTERVAL == 0:
                save_results(results_data, OUTPUT_FILE_PATH)
    finally:
        # Only fully judged samples are in results_data, so saving here is always consistent.
        save_results(results_data, OUTPUT_FILE_PATH)


if __name__ == "__main__":
    # Run the judging pipeline only when executed as a script, not on import.
    main()