# Copyright (c) 2024 metauto.ai
# 
# This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
# This modified version is used for comparing the performance of GPTSwarm with that of KGoT on the SimpleQA dataset.


import os
import argparse
from tqdm import tqdm
import json
import time
import asyncio
import contextlib


from swarm.environment.agents.gaia.tool_tot import ToolTOT
from swarm.utils.globals import Cost, CompletionTokens, PromptTokens, UsageStatisticsObject
from swarm.utils.log import swarmlog, configure_logging


from utils.simple_qa_scorer import SimpleQAScorer
from utils.sampler.chat_completion_sampler import ChatCompletionSampler
from utils.log_and_statistics import UsageStatistics

from utils import common
from utils.simpleqa_types import SingleEvalResult


def _load_json(path):
    """Read and parse the JSON document stored at ``path`` (decoded as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _configure_environment(args):
    """Export LLM credentials and tool settings as environment variables.

    Reads ``args.config_llms_path`` and ``args.config_tools_path``. Raises
    ``FileNotFoundError`` / ``KeyError`` if a file or an expected entry is missing.
    """
    llm_config = _load_json(args.config_llms_path)
    if args.llm_model.startswith("gpt"):
        llm_model_config = llm_config[args.llm_model]
        os.environ["OPENAI_API_KEY"] = llm_model_config["api_key"]
        os.environ["OPENAI_ORGANIZATION"] = llm_model_config["organization"]
    else:
        # Non-"gpt" models get the literal key "ollama"; the multimodal key
        # below is still taken from the "gpt-4o-mini" configuration entry.
        llm_model_config = llm_config["gpt-4o-mini"]
        os.environ["OPENAI_API_KEY"] = "ollama"
    os.environ["MULTIMODE_OPENAI_API_KEY"] = llm_model_config["api_key"]

    for tool in _load_json(args.config_tools_path):
        for key, value in tool["env"].items():
            os.environ[key] = value


def _aggregate_metrics(eval_results):
    """Average the per-question grade flags of a non-empty list of results.

    Returns a dict with the fractions ``is_correct``, ``is_incorrect`` and
    ``is_not_attempted``, their attempted share ``is_given_attempted`` and
    ``accuracy_given_attempted`` (0 when nothing was attempted).
    """
    total = len(eval_results)
    metrics = {
        name: sum(result.metrics[name] for result in eval_results) / total
        for name in ("is_correct", "is_incorrect", "is_not_attempted")
    }
    attempted = metrics["is_correct"] + metrics["is_incorrect"]
    metrics["is_given_attempted"] = attempted
    metrics["accuracy_given_attempted"] = (
        metrics["is_correct"] / attempted if attempted > 0 else 0
    )
    return metrics


async def main():
    """Run the ToolTOT agent over a SimpleQA file and grade every answer.

    Command-line arguments select the dataset, the log folder, the model, the
    maximum number of agent tries and the LLM/tool configuration files.

    Per-question records are appended to ``correct_stats.json`` in the log
    folder; the number of records already stored there is used to resume an
    interrupted run. Standard output produced during the evaluation is appended
    to ``cmd_log.log``. The aggregate metrics printed at the end cover only the
    questions evaluated in this invocation.

    Returns early (without touching any configuration) when the dataset is
    empty or when the stored results already cover every question.

    Raises:
        ValueError: if the grader returns a letter other than "A", "B" or "C".
    """
    parser = argparse.ArgumentParser(description="GPTSwarm Experiments on SimpleQA")
    parser.add_argument('--simpleqa_file', type=str, required=True, help='Path to SimpleQA JSON file')
    parser.add_argument('--log_folder_base', type=str, required=True, help='Base folder for logging results')
    parser.add_argument("--llm_model", type=str, default="gpt-4o-mini")
    parser.add_argument('--max_iterations', type=int, required=False, help='Max iterations for GPTSwarm', default=3)
    parser.add_argument('--config_llms_path', type=str, required=False, help='Path to LLM configuration file', default="kgot/config_llms.json")
    parser.add_argument('--config_tools_path', type=str, required=False, help='Path to tools configuration file', default="kgot/config_tools.json")
    args = parser.parse_args()

    simpleqa_data = _load_json(args.simpleqa_file)
    # The dataset is either a dict wrapping the list under 'rows' or the list itself.
    rows = simpleqa_data['rows'] if isinstance(simpleqa_data, dict) else simpleqa_data
    total_questions = len(rows)
    if total_questions == 0:
        # Nothing to evaluate; averaging over zero results would divide by zero.
        print("No questions found in the SimpleQA file. Nothing to do.")
        return

    log_folder = args.log_folder_base
    log_file_correct_stats = os.path.join(log_folder, "correct_stats.json")

    # KGoT SimpleQA Logic: resume after the questions already recorded on disk.
    already_solved = 0
    try:
        already_solved = len(_load_json(log_file_correct_stats))
    except FileNotFoundError:
        # No log folder or no results file yet: start from the first question.
        pass
    else:
        print(f"Total questions: {total_questions}")
        # ">=" also covers a results file holding more entries than the dataset,
        # which would otherwise leave nothing to process.
        if already_solved >= total_questions:
            print("\033[4;32m\033[1mAll questions already solved. Skipping...\033[0m")
            return
        print(f"\033[4;32m\033[1mAlready solved {already_solved} questions. Starting from {already_solved + 1}...\033[0m")

    # Configure the LLM and tools
    _configure_environment(args)

    os.makedirs(log_folder, exist_ok=True)

    cmd_log = os.path.join(log_folder, "cmd_log.log")
    log_file = os.path.join(log_folder, "output.log")
    llm_cost_json_file = os.path.join(log_folder, "llm_cost.json")
    llm_cost_json_file_total = os.path.join(log_folder, "llm_cost_total.json")
    usage_statistics = UsageStatistics(llm_cost_json_file)
    UsageStatisticsObject.instance(usage_statistics=usage_statistics)

    configure_logging(log_file_path=log_file)

    simpleqa_results = []
    grade_labels = {"A": "Correct", "B": "Incorrect", "C": "Not Attempted"}

    # Everything printed inside this block is appended to cmd_log.log.
    with open(cmd_log, 'a', encoding='utf-8') as redirected_stdout, \
            contextlib.redirect_stdout(redirected_stdout):

        ####################################
        print(args.llm_model)
        agent = ToolTOT(domain="gaia", model_name=args.llm_model)
        agent.display()
        ####################################

        # Built once: the grader's constructor arguments do not depend on the row.
        scorer = SimpleQAScorer(ChatCompletionSampler(model="gpt-4o"))

        for row in tqdm(rows[already_solved:], desc="Processing questions", unit="question"):
            row_idx = row['row_idx']
            record = row['row']
            question = record['Question']
            final_answer = record['Final answer']
            level = record['Level']
            metadata = record['Annotator Metadata']
            topic = metadata.get('topic', '')
            answer_type = metadata.get('answer_type', '')

            start_time = time.time()
            inputs = {"task": question, "files": None, "GT": final_answer}

            swarmlog("🐝GPTSWARM SYS", f"Solving question {row_idx}", Cost.instance().value, PromptTokens.instance().value, CompletionTokens.instance().value, log_file)

            print(inputs)

            answer, tries = await agent.run(inputs=inputs, max_tries=args.max_iterations)

            # Keep only the text following the last "FINAL ANSWER: " marker
            # of the agent's last message.
            answer = answer[-1].split("FINAL ANSWER: ")[-1]

            exe_time = time.time() - start_time

            print("-----")
            print(f"AGENT ANSWER: {answer}")
            print("-----")

            grade_letter = scorer.grade_sample(question, final_answer, answer)
            # Reject an unexpected grade before anything is recorded for this row.
            if grade_letter not in grade_labels:
                raise ValueError(f"Invalid grade letter: {grade_letter}")

            is_correct = grade_letter == "A"
            is_incorrect = grade_letter == "B"
            is_not_attempted = grade_letter == "C"

            score = is_correct

            html = common.jinja_env.from_string(common.HTML_JINJA).render(
                prompt_messages=[dict(content="GPTSwarm", role="system")],
                next_message=dict(content=answer, role="assistant"),
                score=score,
                correct_answer=final_answer,
                extracted_answer=answer,
            )

            simpleqa_results.append(SingleEvalResult(html=html, score=score, metrics={
                "is_correct": is_correct,
                "is_incorrect": is_incorrect,
                "is_not_attempted": is_not_attempted
            }))

            print(f"Row {row_idx}: {grade_labels[grade_letter]} (Expected: {final_answer}, Got: {answer})", flush=True)

            result = {
                "question_number": row_idx,
                "correct_answer": final_answer,
                "returned_answer": answer,
                "grade_letter": grade_letter,
                "successful": is_correct,
                "not_attempted": is_not_attempted,
                "level": level,
                "iterations_taken": tries+1,
                "answer_type": answer_type,
                "topic": topic,
                "GPTSwarm_official_timer": exe_time,
                "GPTSwarm_official_cost_accumulated": Cost.instance().value,
            }

            # Re-read and rewrite the results file after every question so an
            # interrupted run can resume from the number of stored records.
            try:
                results = _load_json(log_file_correct_stats)
            except FileNotFoundError:
                results = []
            results.append(result)

            with open(log_file_correct_stats, 'w', encoding='utf-8') as output_file:
                json.dump(results, output_file, indent=4)

        # Aggregate metrics (simpleqa_results is non-empty here: the loop ran
        # over at least one row and every iteration appends or raises).
        aggregate_metrics = _aggregate_metrics(simpleqa_results)
        print("##################")
        print("AGGREGATE METRICS")
        print(aggregate_metrics)
        print("##################")

        accuracy = aggregate_metrics["accuracy_given_attempted"]
        correct = aggregate_metrics["is_correct"]
        output_d = {
            "accuracy_given_attempted": accuracy,
            # Harmonic mean of accuracy-given-attempted and overall correctness.
            "f1": (
                2 * accuracy * correct / (accuracy + correct)
                if (accuracy + correct) > 0
                else 0
            )
        }

        print(f"Accuracy Given Attempted: {output_d['accuracy_given_attempted']:.3f}")
        print(f"F1 Score: {output_d['f1']:.3f}")

        UsageStatistics.calculate_total_cost(llm_cost_json_file, llm_cost_json_file_total)
                
        
        


# Run the evaluation only when this file is executed as a script (not on
# import); asyncio.run drives the main() coroutine to completion.
if __name__ == '__main__':
    asyncio.run(main())
