# This is a modified version of the code available at https://github.com/centerforaisafety/hle/blob/main/hle_eval/run_judge_results.py.
import os
import json
import copy
import math
import argparse
import asyncio
import numpy as np
from typing import Literal
from pydantic import BaseModel
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from datasets import load_dataset
import utils
import pandas as pd

# Prompt template sent to the judge model. It has two placeholders,
# {response_A} and {response_B}, which extract_answer() fills with the model's
# answer and the reference answer respectively. The text is consumed by the
# judge at runtime, so any edit to it changes judging behaviour.
JUDGE_PROMPT = """Extract the final answers from the following [response A] and [response B] and compare them.

[response A]: {response_A}

[response B]: {response_B}

Your response must be in the format below:

extracted_final_answer_from_response_A: The final exact answer extracted from the [response A]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response A.

extracted_final_answer_from_response_B: The final exact answer extracted from the [response B]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response B.

correct: Answer 'yes' if extracted_final_answer_from_response_A matches the extracted_final_answer_from_response_B given above. The strings do not have to be an exact match, but they must indicate the same answers. Answer 'no' otherwise."""


# Structured-output schema for the judge's reply; it is passed as
# `response_format` to the chat-completions `parse` call in extract_answer().
# Documented with `#` comments on purpose: the class body is left untouched so
# the schema derived from it stays exactly as it was.
class ExtractedAnswer(BaseModel):
    # Final answer the judge extracted from response A (the model's answer).
    extracted_final_answer_from_response_A: str
    # Final answer the judge extracted from response B (the reference answer).
    extracted_final_answer_from_response_B: str
    # Judge verdict on whether the two extracted answers agree.
    correct: Literal["yes", "no"]
    # Copied into the judge result by extract_answer(). JUDGE_PROMPT does not
    # describe this field, so its scale is up to the judge model.
    # NOTE(review): confirm whether the prompt should ask for it explicitly.
    confidence: int
    # Constant field that can only validate as True.
    # NOTE(review): its purpose is not evident from this file.
    strict: Literal[True] # 100% reliability

    
async def extract_answer(question, correct_answer, model_answer):
    """Ask the judge model to extract and compare two final answers.

    The model's answer is sent as response A and the reference answer as
    response B of ``JUDGE_PROMPT``. The reply is parsed into an
    ``ExtractedAnswer``.

    Args:
        question: Question text; copied unchanged into the result.
        correct_answer: Reference answer for the question.
        model_answer: Full response produced by the evaluated model.

    Returns:
        A dict holding the extracted answers, the verdict, the confidence and
        the original inputs, or ``None`` when the request or the handling of
        its reply raised an exception (the error is printed).
    """
    judge_messages = [
        {
            "role": "user",
            "content": JUDGE_PROMPT.format(
                response_A=model_answer, response_B=correct_answer
            ),
        }
    ]

    try:
        completion = await client.beta.chat.completions.parse(
            model=args.judger,
            max_completion_tokens=4096,  # overkill for judge
            messages=judge_messages,
            response_format=ExtractedAnswer,
        )
        parsed = completion.choices[0].message.parsed
        verdict = {
            "correct_answer": parsed.extracted_final_answer_from_response_B,
            "model_answer": parsed.extracted_final_answer_from_response_A,
            "correct": parsed.correct,
            "confidence": parsed.confidence,
            "full_model_answer": model_answer,
            "full_correct_answer": correct_answer,
            "question": question,
        }
    except Exception as e:  # very, very rare
        # Best effort: report the failure and let the caller skip this item.
        print("Error:", e)
        return None

    return verdict
        
async def add_judge_response(question, predictions):
    """Attach a judge verdict to the prediction belonging to ``question``.

    Args:
        question: Mapping with the keys ``"id"``, ``"question"`` and
            ``"answer"``.
        predictions: Container indexed by question id whose items hold at
            least a ``"response"`` entry.

    Returns:
        ``(unique_id, prediction)`` where ``prediction`` is a deep copy of the
        stored prediction carrying a ``"judge_response"`` entry, or
        ``(None, None)`` when the judge call produced no result.
    """
    unique_id = question["id"]
    # Work on a private copy so the caller's predictions are never mutated.
    judged = copy.deepcopy(predictions[unique_id])
    question_text, reference = question["question"], question["answer"]

    # Only call the judge when this prediction has no verdict yet.
    if "judge_response" not in judged:
        verdict = await extract_answer(question_text, reference, judged["response"])
        if verdict is None:
            return None, None
        judged["judge_response"] = verdict

    return unique_id, judged

async def judge_all_responses(questions, predictions):
    """Judge every question concurrently, bounded by ``args.num_workers``.

    Args:
        questions: Iterable of question mappings handed one by one to
            ``add_judge_response``.
        predictions: Container of predictions indexed by question id.

    Returns:
        The per-question results of ``add_judge_response`` as gathered by
        ``tqdm_asyncio.gather`` (one ``(unique_id, prediction)`` pair each).
    """
    semaphore = asyncio.Semaphore(args.num_workers)

    async def bound_func(question):
        # Each task holds one permit while its judge request is in flight.
        async with semaphore:
            return await add_judge_response(question, predictions)

    # The gather itself must not hold a permit: doing so would reduce the
    # usable concurrency by one and deadlock when only one permit exists.
    tasks = [bound_func(q) for q in questions]
    return await tqdm_asyncio.gather(*tasks)


def compute_acc(dataset):
    """Compute accuracy statistics over judged predictions.

    Args:
        dataset: Mapping from an id to a prediction dict. A prediction counts
            as judged when it has a ``"judge_response"`` entry holding
            ``"model_answer"``, ``"correct_answer"`` and ``"correct"``.

    Returns:
        A tuple ``(accuracy, n_correct, n_total, n_valid)``:
        ``accuracy`` is ``n_correct / n_total`` (``0.0`` for an empty
        dataset), ``n_total`` is ``len(dataset)``, and ``n_valid`` counts the
        judged entries for which neither extracted answer is ``"None"``.
        Entries that are unjudged or not valid still count in ``n_total``.
    """
    n_total = len(dataset)
    n_correct = 0
    n_valid = 0

    for key, entry in dataset.items():
        if "judge_response" not in entry:
            print(f"Missing judge response for {key}, you should rerun the judge")
            continue

        judge_response = entry["judge_response"]
        # The judge reports 'None' when it could not extract a final answer;
        # such entries are neither valid nor correct.
        if "None" in (judge_response["model_answer"], judge_response["correct_answer"]):
            continue

        n_valid += 1
        if "yes" in judge_response["correct"]:
            n_correct += 1

    # Guard the division: an empty dataset has no meaningful accuracy.
    accuracy = n_correct / n_total if n_total else 0.0

    return accuracy, n_correct, n_total, n_valid


def main(args):
    """Judge all predictions in a JSON file and write the results to disk.

    Writes ``judged_<input name>`` (JSON) and a CSV of metrics with the same
    stem, either into ``args.output_folder`` or next to the input file.

    Args:
        args: Parsed command-line namespace providing ``json_file_path``,
            ``num_workers``, ``judger`` and ``output_folder``.

    Raises:
        ValueError: If ``args.num_workers`` is smaller than 2.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.num_workers < 2:
        raise ValueError("num_workers must be 2 or greater")

    # Resolve the output location. os.path keeps a bare filename such as
    # "preds.json" relative instead of turning it into "/judged_preds.json".
    input_name = os.path.basename(args.json_file_path)
    if args.output_folder is not None:
        utils.ensure_dir(args.output_folder)
        output_dir = args.output_folder
    else:
        output_dir = os.path.dirname(args.json_file_path)
    output_filepath = os.path.join(output_dir, f"judged_{input_name}")

    # Load dataset and give every row a positional id
    dataset = load_dataset('json', data_files=args.json_file_path)['train']
    dataset = dataset.add_column('id', range(len(dataset)))

    # Prepare questions
    questions = dataset.select_columns(['id', 'question', 'answer'])
    total_questions = len(questions)

    # Prepare generated responses
    predictions = dataset.select_columns(['id', 'response'])

    # Begin judging
    results = asyncio.run(judge_all_responses(questions, predictions))

    # Failed judge calls come back as (None, None) and are left out.
    judged_predictions = {
        unique_id: judged for unique_id, judged in results if unique_id is not None
    }

    # Make dropped items visible: they are excluded from the metrics below.
    n_failed = total_questions - len(judged_predictions)
    if n_failed:
        print(
            f"Warning: {n_failed} of {total_questions} responses could not be judged "
            "and are excluded from the metrics; you should rerun the judge"
        )

    # Cache judge output
    with open(output_filepath, "w") as f:
        json.dump(judged_predictions, f, indent=4)

    # Display accuracy
    acc, n_correct, n, n_valid = compute_acc(judged_predictions)
    print(f"Accuracy: {acc * 100:.2f}")

    # Save metrics next to the judged JSON, swapping only the final extension
    df = pd.DataFrame({"Accuracy": [acc * 100], "N of correct predictions": [n_correct], "N": [n], "N of valid predictions": [n_valid]})
    df.to_csv(os.path.splitext(output_filepath)[0] + ".csv")

if __name__ == "__main__":
    # Get arguments. `args` is kept as a module-level name because
    # extract_answer() and judge_all_responses() read it directly.
    parser = argparse.ArgumentParser()
    # Required: main() derives the output paths from it and cannot run without it.
    parser.add_argument("--json_file_path", type=str, required=True, help="a path to JSON file containing model predictions")
    parser.add_argument("--num_workers", type=int, default=100, help="Async semaphore size. This depends on your rate limit.")
    parser.add_argument("--judger", type=str, default="gpt-4o-mini", help="Judge model") # prev: "gpt-4o-2024-08-06"
    parser.add_argument("--output_folder", type=str, default=None, help="Output folder")
    args = parser.parse_args()

    # Initialize client. `client` is a module-level name used by extract_answer().
    client = AsyncOpenAI(timeout=300.0, max_retries=1)

    main(args)