import json
from typing import List

import argparse
import os
from dotenv import load_dotenv
load_dotenv()

from evaluation.prompts import qa_eval_prompt
from openai import OpenAI
from benchmarks.GPQA_Diamond.loader import GPQADataset

class GPQAEvaluator:
    """Grade GPQA predictions with an LLM judge reached through an OpenAI-compatible client.

    Each predicted answer is compared with the dataset's labeled answer by sending
    ``qa_eval_prompt`` to ``evaluator_model`` and looking for a "Correct" /
    "Incorrect" verdict in the reply.
    """

    def __init__(self, subject: str, client: OpenAI, evaluator_model: str = "deepseek/deepseek-chat-v3.1:free"):
        """Store the evaluation configuration.

        Args:
            subject: GPQA subject whose ground truth is loaded, or ``"All"`` for the full set.
            client: OpenAI-compatible client used to query the judge model.
            evaluator_model: Model identifier passed to ``chat.completions.create``.
        """
        self.subject = subject
        self.client = client
        self.evaluator_model = evaluator_model

    def evaluate_single_question(self, question: str, gt_answer: str, pred_answer: str, max_retries: int = 5) -> bool:
        """Ask the judge model whether ``pred_answer`` matches ``gt_answer``.

        Args:
            question: Question text, inserted into the judge prompt.
            gt_answer: Labeled (ground-truth) answer.
            pred_answer: Predicted answer to grade.
            max_retries: Maximum number of judge calls made while waiting for a
                reply that contains a verdict. Must be at least 1.

        Returns:
            False if the judge's reply contains "Incorrect", otherwise True if it
            contains "Correct".

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            RuntimeError: If none of the ``max_retries`` replies contained a verdict.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        prompt = qa_eval_prompt.format(
            question=question,
            labeled_answer=gt_answer,
            pred_answer=pred_answer,
        )

        last_content = None
        for _ in range(max_retries):
            response = self.client.chat.completions.create(
                model=self.evaluator_model,
                messages=[
                    {"role": "user", "content": prompt}
                ],
            )

            # An empty choice list or a null/empty message body carries no verdict;
            # treat it like any other unparseable reply and ask again.
            if not response.choices:
                continue
            content = response.choices[0].message.content
            if not content:
                continue
            last_content = content

            # Check "Incorrect" first: an "Incorrect" verdict can still mention the
            # word "Correct" elsewhere (e.g. "Incorrect. Correct answer: B"), which
            # would otherwise be misread as a pass.
            if "Incorrect" in content:
                return False
            if "Correct" in content:
                return True

        raise RuntimeError(
            f"Judge model {self.evaluator_model!r} gave no Correct/Incorrect verdict "
            f"after {max_retries} attempts; last reply: {last_content!r}"
        )

    def evaluate_complete_result(self, result_json_path: str) -> List[bool]:
        """Grade every prediction in a result file against the GPQA ground truth.

        Args:
            result_json_path: Path to a JSON file holding a list of objects, each
                with ``"id"``, ``"question"`` and ``"answer"`` keys.

        Returns:
            One boolean per entry in the file, in file order; True means the judge
            marked the prediction correct.

        Raises:
            KeyError: If an entry's ``id`` has no ground-truth answer in the loaded
                dataset (for example, it belongs to a different subject).
        """
        # Load predictions.
        with open(result_json_path, 'r', encoding='utf-8') as f:
            result_json = json.load(f)

        if self.subject == "All":
            dataset = GPQADataset().get_full_set()
        else:
            dataset = GPQADataset().get_by_subject(self.subject)

        # Index ground-truth answers once instead of scanning the dataset for every
        # prediction. setdefault keeps the first answer seen for a duplicated id,
        # matching a first-match linear search.
        gt_by_id = {}
        for record in dataset:
            gt_by_id.setdefault(record['id'], record['answer'])

        result = []
        for item in result_json:
            item_id = item['id']
            question = item['question']
            pred_answer = item['answer']

            # Grading against a missing label would silently count as a failure
            # and skew accuracy, so surface the mismatch instead.
            if item_id not in gt_by_id:
                raise KeyError(
                    f"No ground-truth answer for id {item_id!r} in subject {self.subject!r}"
                )
            gt_answer = gt_by_id[item_id]

            flag = self.evaluate_single_question(question, gt_answer, pred_answer)
            if flag:
                print(f"Task {question} passed.")
            else:
                print(f"Task {question} failed.")

            result.append(flag)
        return result

if __name__ == "__main__":
    # Command-line entry point: grade a prediction file and print overall accuracy.
    parser = argparse.ArgumentParser(
        description="Grade GPQA predictions with an LLM judge served through OpenRouter."
    )
    parser.add_argument(
        "--subject",
        type=str,
        default="All",
        help='GPQA subject to grade against, or "All" for the full set.',
    )
    parser.add_argument(
        "--result_json_path",
        type=str,
        required=True,
        help="JSON file containing a list of predictions with id/question/answer fields.",
    )
    args = parser.parse_args()

    # OpenRouter exposes an OpenAI-compatible endpoint; the API key is read from the
    # environment, which load_dotenv() at module import may populate from a .env file.
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.getenv("OPENROUTER_API_KEY"),
    )

    evaluator = GPQAEvaluator(subject=args.subject, client=client)
    result = evaluator.evaluate_complete_result(args.result_json_path)

    if result:
        print(f"Final accuracy: {sum(result) / len(result)}")
    else:
        # An empty prediction file would otherwise raise ZeroDivisionError here.
        print("Final accuracy: undefined (no predictions to evaluate)")