# Copyright (c) 2024 metauto.ai
# 
# This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
# This modified version is used for comparing the performance of GPTSwarm with the performance of KGoT on the GAIA dataset.


import os
import argparse
from tqdm import tqdm
import json
import time
import asyncio
import contextlib


from swarm.environment.agents.gaia.tool_tot import ToolTOT
from swarm.utils.globals import Cost, CompletionTokens, PromptTokens, UsageStatisticsObject
from swarm.utils.log import swarmlog, configure_logging

from utils.gaia_scorer import check_close_call, question_scorer
from utils.log_and_statistics import UsageStatistics

def _load_json(path):
    """Parse and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _configure_llm_environment(llm_model, config_llms_path):
    """Export the OpenAI-related environment variables for *llm_model*.

    Model names starting with "gpt" take their API key and organization from
    their own entry in the LLM configuration file. For every other model name
    OPENAI_API_KEY is set to the literal "ollama" and the multimodal key is
    taken from the "gpt-4o-mini" entry.

    Raises:
        KeyError: If the required entry or field is missing from the file.
    """
    llm_config = _load_json(config_llms_path)
    if llm_model.startswith("gpt"):
        llm_model_config = llm_config[llm_model]
        os.environ["OPENAI_API_KEY"] = llm_model_config["api_key"]
        os.environ["OPENAI_ORGANIZATION"] = llm_model_config["organization"]
    else:
        llm_model_config = llm_config["gpt-4o-mini"]
        os.environ["OPENAI_API_KEY"] = "ollama"
    os.environ["MULTIMODE_OPENAI_API_KEY"] = llm_model_config["api_key"]


def _configure_tool_environment(config_tools_path):
    """Export every variable listed under each tool's "env" mapping."""
    for tool in _load_json(config_tools_path):
        for key, value in tool["env"].items():
            os.environ[key] = value


async def main():
    """Run the ToolTOT agent over the GAIA questions and log per-question results.

    Command-line arguments select the GAIA JSON file, the attachments folder,
    the log folder, the model name, the maximum number of tries per question
    and the LLM/tool configuration files.

    The run is resumable: when ``correct_stats.json`` already exists in the log
    folder, as many leading questions as it has entries are skipped. If it
    already holds one entry per question the function returns immediately.
    After every question the result is appended to ``correct_stats.json``;
    once the loop ends a summary is printed and the total LLM cost is written.
    Everything printed while the agent runs is redirected to ``cmd_log.log``.
    """
    parser = argparse.ArgumentParser(description="GPTSwarm Experiments on GAIA")
    parser.add_argument('--gaia_file', type=str, required=True, help='Path to GAIA JSON file')
    parser.add_argument('--attachment_folder', type=str, required=False, help='Path to GAIA problems attachments folder', default="GAIA/dataset/attachments/validation")
    parser.add_argument('--log_folder_base', type=str, required=True, help='Base folder for logging results')
    parser.add_argument("--llm_model", type=str, default="gpt-4o-mini")
    parser.add_argument('--max_iterations', type=int, required=False, help='Max iterations for GPTSwarm', default=3)
    parser.add_argument('--config_llms_path', type=str, required=False, help='Path to LLM configuration file', default="kgot/config_llms.json")
    parser.add_argument('--config_tools_path', type=str, required=False, help='Path to tools configuration file', default="kgot/config_tools.json")
    args = parser.parse_args()

    gaia_data = _load_json(args.gaia_file)
    # The dataset is either a dict carrying the questions under "rows" or a bare list.
    rows = gaia_data['rows'] if isinstance(gaia_data, dict) else gaia_data

    # KGoT GAIA Logic: resume after the questions recorded by a previous run.
    already_solved = 0
    try:
        previous_results = _load_json(os.path.join(args.log_folder_base, "correct_stats.json"))
    except FileNotFoundError:
        # No log folder or no stats file yet: start from the first question.
        pass
    else:
        already_solved = len(previous_results)
        total_questions = len(rows)
        print(f"Total questions: {total_questions}")
        if already_solved == total_questions:
            print("\033[4;32m\033[1mAll questions already solved. Skipping...\033[0m")
            # Plain return instead of the site-provided exit(): nothing is left to do
            # and the process still terminates with status 0.
            return

        print(f"\033[4;32m\033[1mAlready solved {already_solved} questions. Starting from {already_solved + 1}...\033[0m")

    # Configure the LLM and tools
    _configure_llm_environment(args.llm_model, args.config_llms_path)
    _configure_tool_environment(args.config_tools_path)

    log_folder = args.log_folder_base
    os.makedirs(log_folder, exist_ok=True)

    cmd_log = os.path.join(log_folder, "cmd_log.log")
    log_file = os.path.join(log_folder, "output.log")
    log_file_correct_stats = os.path.join(log_folder, "correct_stats.json")
    llm_cost_json_file = os.path.join(log_folder, "llm_cost.json")
    llm_cost_json_file_total = os.path.join(log_folder, "llm_cost_total.json")
    usage_statistics = UsageStatistics(llm_cost_json_file)
    UsageStatisticsObject.instance(usage_statistics=usage_statistics)

    configure_logging(log_file_path=log_file)

    # UTF-8 so that printing question text never fails on a non-UTF-8 locale.
    with open(cmd_log, 'a', encoding='utf-8') as redirected_stdout:
        with contextlib.redirect_stdout(redirected_stdout):  # redirect stdout to log file

            ####################################
            print(args.llm_model)
            agent = ToolTOT(domain="gaia", model_name=args.llm_model)
            agent.display()
            ####################################
            for row in tqdm(rows[already_solved:], desc="Processing questions", unit="question"):

                row_idx = row['row_idx']
                task = row['row']
                metadata = task['Annotator Metadata']

                question = task['Question']
                final_answer = task['Final answer']
                file_name = task['file_name']
                file_path = args.attachment_folder  # row['row']['file_path'] not used as attachments have been downloaded locally
                level = task['Level']
                num_steps = metadata.get('Number of steps', '')
                tools = metadata.get('Tools', '')
                num_tools = metadata.get('Number of tools', '')

                start_time = time.time()
                # Without an attachment the falsy file_name itself is passed through unchanged.
                files = [os.path.join(file_path, file_name)] if file_name else file_name
                inputs = {"task": question, "files": files, "GT": final_answer}

                swarmlog("🐝GPTSWARM SYS", f"Solving question {row_idx}", Cost.instance().value, PromptTokens.instance().value, CompletionTokens.instance().value, log_file)

                print(inputs)

                answer, tries = await agent.run(inputs=inputs, max_tries=args.max_iterations)

                # Keep only the text after the last "FINAL ANSWER: " marker of the last entry.
                answer = answer[-1].split("FINAL ANSWER: ")[-1]

                exe_time = time.time() - start_time

                print("-----")
                print(f"AGENT ANSWER: {answer}")
                print("-----")

                successful = question_scorer(answer, final_answer)
                close_call = check_close_call(answer, final_answer, successful)

                if successful:
                    print(f"Row {row_idx}: Correct (Expected: {final_answer}, Got: {answer})", flush=True)
                elif close_call:
                    print(f"Row {row_idx}: Close Call (Expected: {final_answer}, Got: {answer})", flush=True)
                else:
                    print(f"Row {row_idx}: Incorrect (Expected: {final_answer}, Got: {answer})", flush=True)

                result = {
                    "question_number": row_idx,
                    "correct_answer": final_answer,
                    "returned_answer": answer,
                    "successful": successful,
                    "close_call": close_call,
                    "level": level,
                    "iterations_taken": tries+1,
                    "num_steps": num_steps,
                    "tools": tools,
                    "num_tools": num_tools,
                    "GPTSwarm_official_timer": exe_time,
                    "GPTSwarm_official_cost_accumulated": Cost.instance().value,
                }
                # Re-read and rewrite the stats file after every question so that an
                # interrupted run can be resumed from the last recorded result.
                try:
                    results = _load_json(log_file_correct_stats)
                except FileNotFoundError:
                    results = []
                results.append(result)

                # Write the updated results back to the file
                with open(log_file_correct_stats, 'w', encoding='utf-8') as output_file:
                    json.dump(results, output_file, indent=4)

            # With nothing processed and no earlier run there is no stats file to read.
            try:
                results = _load_json(log_file_correct_stats)
            except FileNotFoundError:
                results = []

            total_questions = len(results)
            # NOTE(review): the figure printed as "Correct answers" counts entries whose
            # 'close_call' flag is truthy, not 'successful' — confirm this is intended.
            correct_answers = sum(1 for result in results if result['close_call'])

            if total_questions > 0:
                percentage_correct = (correct_answers / total_questions) * 100
                print(f"\nTotal questions: {total_questions}")
                print(f"Correct answers: {correct_answers}")
                print(f"Percentage correct: {percentage_correct:.2f}%")
            else:
                print("No questions to evaluate based on the provided filter.")

            UsageStatistics.calculate_total_cost(llm_cost_json_file, llm_cost_json_file_total)
                

# Script entry point: drive the async main() to completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
