import json
from typing import Dict, List, Optional

import argparse
import os
from dotenv import load_dotenv
load_dotenv()

from evaluation.prompts import qa_eval_prompt
from openai import OpenAI
from datasets import load_dataset

class GAIAEvaluator:
    """LLM-as-judge evaluator for GAIA predictions.

    Each prediction is compared against the dataset's ``Final answer`` by
    sending ``qa_eval_prompt`` to ``evaluator_model`` through an
    OpenAI-compatible ``client`` and parsing a ``Correct``/``Incorrect``
    verdict from the reply.
    """

    def __init__(self, level: int, split: str, client: "OpenAI", evaluator_model: str = "deepseek/deepseek-chat-v3.1:free"):
        """
        Args:
            level: GAIA level; selects the ``2023_level{level}`` dataset config.
            split: Dataset split to load (e.g. ``"validation"``).
            client: OpenAI-compatible client used to query the judge model.
            evaluator_model: Model identifier passed to
                ``client.chat.completions.create``.
        """
        self.level = level
        self.split = split
        self.evaluator_model = evaluator_model
        self.client = client

    @staticmethod
    def _parse_verdict(content: Optional[str]) -> Optional[bool]:
        """Map a judge reply to True/False, or None if it contains no verdict."""
        if not content:
            return None
        # "Incorrect" is checked first: a negative reply may still mention
        # "Correct" (e.g. "Incorrect. Correct answer: ..."), and an ambiguous
        # reply should not be counted as a pass.
        if "Incorrect" in content:
            return False
        if "Correct" in content:
            return True
        return None

    def evaluate_single_question(self, question: str, gt_answer: str, pred_answer: str, max_retries: int = 5) -> bool:
        """Ask the judge model whether ``pred_answer`` matches ``gt_answer``.

        Args:
            question: The task question, inserted into the judge prompt.
            gt_answer: Ground-truth answer (``labeled_answer`` in the prompt).
            pred_answer: The model's predicted answer.
            max_retries: Maximum number of judge calls made while the reply
                contains neither "Correct" nor "Incorrect". Must be >= 1.

        Returns:
            True if the judge says "Correct", False if it says "Incorrect".

        Raises:
            RuntimeError: If no verdict is obtained within ``max_retries`` calls.
        """
        prompt = qa_eval_prompt.format(
            question=question,
            labeled_answer=gt_answer,
            pred_answer=pred_answer,
        )

        for _ in range(max_retries):
            response = self.client.chat.completions.create(
                model=self.evaluator_model,
                messages=[{"role": "user", "content": prompt}],
            )
            # Some providers return no choices or a None content on failure;
            # treat both as "no verdict" and retry instead of crashing.
            choices = response.choices
            content = choices[0].message.content if choices else None
            verdict = self._parse_verdict(content)
            if verdict is not None:
                return verdict

        raise RuntimeError(
            f"Evaluator model {self.evaluator_model!r} gave no Correct/Incorrect "
            f"verdict after {max_retries} attempt(s)."
        )

    def _load_ground_truth(self) -> Dict[str, str]:
        """Return a ``task_id -> Final answer`` mapping for this level and split."""
        dataset = load_dataset("./datasets/GAIA/GAIA.py", name=f"2023_level{self.level}", data_dir=".", split=self.split, trust_remote_code=True)
        gt_answers: Dict[str, str] = {}
        for row in dataset:
            # setdefault keeps the first row per task_id, matching a linear search.
            gt_answers.setdefault(row['task_id'], row['Final answer'])
        return gt_answers

    def evaluate_complete_result(self, result_json_path: str) -> List[bool]:
        """Judge every prediction in ``result_json_path``.

        The file is read as a JSON list of objects with ``'id'`` and
        ``'question'`` keys and an optional ``'answer'`` key. Items without an
        ``'answer'``, or whose ``'id'`` has no ground truth in the dataset, are
        skipped.

        Returns:
            One pass/fail flag per evaluated item, in file order.
        """
        # Build the lookup once instead of rescanning the dataset per prediction.
        gt_answers = self._load_ground_truth()

        with open(result_json_path, 'r', encoding='utf-8') as f:
            result_json = json.load(f)

        result: List[bool] = []

        for item in result_json:
            task_id = item['id']
            question = item['question']
            print(f"Evaluating task: {task_id}\nQuestion: {question}")

            if 'answer' not in item:
                continue

            gt_answer = gt_answers.get(task_id)
            if gt_answer is None:
                # Without a reference answer the judge would be asked to compare
                # against the literal string "None"; skip instead.
                print(f"Task {task_id} skipped: no ground-truth answer found.")
                continue

            flag = self.evaluate_single_question(question, gt_answer, item['answer'])
            if flag:
                print(f"Task {task_id} passed.")
            else:
                print(f"Task {task_id} failed.")

            result.append(flag)

        return result
    
if __name__ == "__main__":
    # CLI: score a predictions JSON file against GAIA ground truth with an LLM judge.
    parser = argparse.ArgumentParser(description="Evaluate GAIA predictions with an LLM judge.")
    parser.add_argument("--level", type=int, default=1)
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--result_json_path", type=str, required=True)
    parser.add_argument(
        "--evaluator_model",
        type=str,
        default=None,
        help="Judge model id; when omitted, GAIAEvaluator's default is used.",
    )
    args = parser.parse_args()

    # The judge is reached through OpenRouter's OpenAI-compatible endpoint.
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=os.getenv("OPENROUTER_API_KEY"),
    )

    # Only forward the model when given, so the class default stays the single source of truth.
    evaluator_kwargs = {}
    if args.evaluator_model is not None:
        evaluator_kwargs["evaluator_model"] = args.evaluator_model
    evaluator = GAIAEvaluator(level=args.level, split=args.split, client=client, **evaluator_kwargs)

    result = evaluator.evaluate_complete_result(args.result_json_path)

    # Guard against ZeroDivisionError when nothing was evaluated.
    if result:
        print(f"Final accuracy: {sum(result) / len(result)}")
    else:
        print("Final accuracy: N/A (no predictions were evaluated)")
