import argparse
import json
import logging
import os
import re
import traceback
from typing import Any, Dict, List

from dotenv import load_dotenv


import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from agent.Agent_async import Actor
from tools.unified_tool import UnifiedTool, ToolRegistry
from tools.base import Configuration, Server
from gaia_utils import (
    add_file_path,
    load_dataset_meta,
    question_scorer,
    report_results,
)
import asyncio

# Create log directory if it doesn't exist
if not os.path.exists(os.getenv("LOG_FILE_PATH", "./logs")):
    os.makedirs(os.getenv("LOG_FILE_PATH", "./logs"))

parser = argparse.ArgumentParser()
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="Start index of the dataset",
)
parser.add_argument(
    "--end",
    type=int,
    default=20,
    help="End index of the dataset",
)
parser.add_argument(
    "--q",
    type=str,
    help="Question Index, e.g., 0-0-0-0-0. Highest priority: override other arguments if provided.",
)
parser.add_argument(
    "--skip",
    action="store_true",
    help="Skip the question if it has been processed before.",
)
parser.add_argument(
    "--split",
    type=str,
    default="validation",
    help="Split of the dataset, e.g., validation, test",
)
parser.add_argument(
    "--blacklist_file_path",
    type=str,
    nargs="?",
    help="Blacklist file path, e.g., blacklist.txt",
)
args = parser.parse_args()


def setup_logging():
    logging_logger = logging.getLogger()
    logging_logger.setLevel(logging.INFO)

    log_file_name = (
        f"/super_agent_{args.q}.log"
        if args.q
        else f"/super_agent_{args.start}_{args.end}.log"
    )
    file_handler = logging.FileHandler(
        os.getenv(
            "LOG_FILE_PATH",
            "./logs",
        )
        + log_file_name,
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    logging_logger.addHandler(file_handler)


async def main():
    load_dotenv(override=True)
    setup_logging()

    gaia_dataset_path = os.getenv("GAIA_DATASET_PATH", "C:/Users/87515/Documents/PyProject/AWorld/gaia")
    full_dataset = load_dataset_meta(gaia_dataset_path, split=args.split)
    logging.info(f"Total questions: {len(full_dataset)}")


    ######################################################### tools register #########################################################
    tool_registry = ToolRegistry()

    config = Configuration()
    server_config = config.load_config("gaia/mcp.json")

    servers = [
        Server(name, srv_config)
        for name, srv_config in server_config["mcpServers"].items()
    ]

    for server in servers:
        await server.initialize()

    # Register MCP server tools
    for server in servers:
        await tool_registry.register_server_tools(server)
    ######################################################### MultiAgentSystem #########################################################


    from gaia_flow import MultiAgentSystem
    flow = MultiAgentSystem(tools=tool_registry.get_all_tools())



    ######################################################### load results from the checkpoint file #########################################################
    if os.path.exists(os.getenv("LOG_FILE_PATH", "./logs") + "/results.json"):
        with open(
            os.getenv("LOG_FILE_PATH", "./logs") + "/results.json", "r", encoding="utf-8"
        ) as results_f:
            results: List[Dict[str, Any]] = json.load(results_f)
    else:
        results: List[Dict[str, Any]] = []

    # load blacklist `task_id`
    if args.blacklist_file_path and os.path.exists(args.blacklist_file_path):
        with open(args.blacklist_file_path, "r", encoding="utf-8") as f:
            blacklist = set(f.read().splitlines())
    else:
        blacklist = set()  # Empty set if file doesn't exist

    ######################################################### main loop #########################################################
    try:
        # slice dataset by args.start and args.end, overrided by args.q (single `task_id`)
        dataset_slice = (
            [
                dataset_record
                for idx, dataset_record in enumerate(full_dataset)
                if dataset_record["task_id"] in args.q
            ]
            if args.q is not None
            else full_dataset[args.start : args.end]
        )

        # main loop to execute questions
        for i, dataset_i in enumerate(dataset_slice):
            # specify `task_id`
            if args.q and args.q != dataset_i["task_id"]:
                continue
            # only valid for args.q==None
            if not args.q:
                # blacklist
                if dataset_i["task_id"] in blacklist:
                    continue

                # pass
                if any(
                    # Question Done and Correct
                    (result["task_id"] == dataset_i["task_id"] and result["is_correct"])
                    for result in results
                ) or any(
                    # Question Done and Incorrect, but Level is 3
                    (
                        result["task_id"] == dataset_i["task_id"]
                        and not result["is_correct"]
                        and dataset_i["Level"] == 3
                    )
                    for result in results
                ):
                    continue

                # skip
                if args.skip and any(
                    # Question Done and Correct
                    (result["task_id"] == dataset_i["task_id"])
                    for result in results
                ):
                    continue

            # run
            try:
                logging.info(f"Start to process: {dataset_i['task_id']}")
                logging.info(f"Detail: {dataset_i}")
                logging.info(f"Question: {dataset_i['Question']}")
                logging.info(f"Level: {dataset_i['Level']}")
                logging.info(f"Tools: {dataset_i['Annotator Metadata']['Tools']}")

                question = add_file_path(
                    dataset_i, file_path=gaia_dataset_path, split=args.split
                )["Question"]

                task = question
                result = await flow.run(task)

                answer = result['answer']
                is_correct = False

                if answer:
                    logging.info(f"Agent answer: {answer}")
                    logging.info(f"Correct answer: {dataset_i['Final answer']}")
                    is_correct = question_scorer(answer, dataset_i["Final answer"])
                    if is_correct:
                        logging.info(f"Question {i} Correct!")
                    else:
                        logging.info("Incorrect!")

                # Create the new result record
                new_result = {
                    "task_id": dataset_i["task_id"],
                    "level": dataset_i["Level"],
                    "question": question,
                    "answer": dataset_i["Final answer"],
                    "response": answer,
                    "is_correct": is_correct,
                }

                # Check if this task_id already exists in results
                existing_index = next(
                    (
                        i
                        for i, result in enumerate(results)
                        if result["task_id"] == dataset_i["task_id"]
                    ),
                    None,
                )

                if existing_index is not None:
                    # Update existing record
                    results[existing_index] = new_result
                    logging.info(
                        f"Updated existing record for task_id: {dataset_i['task_id']}"
                    )
                else:
                    # Append new record
                    results.append(new_result)
                    logging.info(
                        f"Added new record for task_id: {dataset_i['task_id']}"
                    )

            except Exception as e:
                logging.error(f"Error processing {i}: {traceback.format_exc()}")
                continue
    except KeyboardInterrupt:
        pass
    finally:
        # report
        report_results(results)
        with open(
            os.getenv("LOG_FILE_PATH", "./logs") + "/results.json", "w", encoding="utf-8"
        ) as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
        for server in servers:
            await server.cleanup()

if __name__ == "__main__":
    asyncio.run(main())

