import argparse
import json
import time
from typing import List

from langchain_core.exceptions import OutputParserException
from pydantic.v1 import BaseModel, Field
from tqdm import tqdm

from src.eval.baseline.prompts import prompt_correctness, prompt_crash, prompt_oracle
from src.eval.toolfuzz.utils.setup import init_model, setup_env_vars
from src.eval.toolfuzz.utils.tools import get_composio_tools, get_langchain_tools
from src.toolfuzz.agent_executors.langchain.react_new import ReactAgentNew
from src.toolfuzz.agent_executors.langchain.react_old import ReactAgentOld
from src.toolfuzz.correctness.correctness_fuzzer import callback_manager_serializer
from src.toolfuzz.correctness.prompt_generation.llm_responses import CorrectnessResponse
from src.toolfuzz.result_classes import Budget, TestFailureResult, TestResult
from src.toolfuzz.runtime.prompt_generation.prompt_generator import RuntimeFailurePromptGeneration
from src.toolfuzz.tools.info_extractors.langchain_tool_wrapper import LangchainToolWrapper
from src.toolfuzz.tools.info_extractors.tool_wrapper_factory import ToolWrapperFactory
from src.toolfuzz.utils import save_test_results


def args(argv=None):
    """
    Parse the command line options of the baseline evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the default) the
            arguments are taken from ``sys.argv``, which is the behaviour of the no-argument call.

    Returns:
        The ``argparse.Namespace`` with the attributes ``agent_model``, ``prompt_model``, ``tool``,
        ``langchain`` and ``runtime``.
    """
    parser = argparse.ArgumentParser(description="Test agent tools.")
    parser.add_argument('-am', dest='agent_model', help="Model used by the agent under test")
    parser.add_argument('-pm', dest='prompt_model', help="Model used for generating the prompts")
    parser.add_argument('-t', dest='tool', help="Name of the tool to test")
    parser.add_argument('-l', dest='langchain', action='store_true',
                        help="Look the tool up in the langchain tools instead of the composio tools")
    parser.add_argument('-rt', dest='runtime', action='store_true',
                        help="Run the runtime failure test instead of the correctness test")
    return parser.parse_args(argv)


# Parsed once, when the module is loaded. failure_test, correct_test and main read their
# settings (tool name, model names, mode flags) from this module-level namespace.
cl_args = args()


# Response schema passed to the prompt generator: the model answer is expected to carry a list of
# generated prompts under the `prompts` key.
# NOTE(review): the field description mentions tool crashes, but this schema is also used when
# generating the correctness prompts - confirm that this is intended.
# NOTE(review): documented with comments instead of a class docstring on purpose - pydantic may
# expose a class docstring as the schema description; confirm before adding one.
class GeneratedPrompt(BaseModel):
    prompts: List[str] = Field(
        description="Multiple prompts which can trigger tool crash")


def failure_test(tools):
    """
    Function which will generate the results for the grey box baseline for runtime failures.

    The tool named by the ``-t`` command line argument is looked up in ``tools``. The prompt model is
    then asked repeatedly (``prompt_crash`` template) for a batch of prompts, every prompt is run
    through the agent and one ``TestFailureResult`` is recorded per prompt. When the budget check
    after a batch succeeds, the results are saved to ``./fail_bl_gr_<tool name>_over_budget.json``
    and the function returns.

    Args:
        tools: The tools which will be evaluated

    Raises:
        ValueError: If none of the tools has the name given on the command line.
    """
    tool_to_test = None
    tool_wrapper = None
    for tool in tools:
        candidate = ToolWrapperFactory.create_extractor(tool)
        if candidate.get_tool_name() == cl_args.tool:
            tool_to_test, tool_wrapper = tool, candidate
            break
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`, which would let
    # the run continue with the wrapper of whatever tool happened to be inspected last.
    if tool_to_test is None:
        raise ValueError(f"Tool {cl_args.tool} not found in the provided tools")

    budget = Budget(time_limit=300, agent_token_limit=25_000, prompt_token_limit=25_000, agent_cost_limit=0.1,
                    prompt_cost_limit=0.1)

    model, cb_handler = init_model(cl_args.prompt_model)
    pg = RuntimeFailurePromptGeneration(tool_wrapper, model, cb_handler)

    agent_exec = ReactAgentNew(tool_to_test, cl_args.agent_model)
    tool_docs = tool_wrapper.get_tool_docs()
    tool_name = tool_wrapper.get_tool_name()
    results_path = f"./fail_bl_gr_{tool_name}_over_budget.json"

    failures = []
    start_time = time.time()

    while True:
        # Set before the prompt loop so that the report and the budget check below are also
        # defined when the model returns an empty prompt list.
        timing = time.time() - start_time
        try:
            prompts = pg._generate_from_template(prompt_crash, ['tool_info'], {'tool_info': tool_docs},
                                                 GeneratedPrompt)
        except OutputParserException as e:
            print(f"Cannot parse output: {e}")
            # The hard time cap also applies while retrying, otherwise a model that never returns
            # parsable output would keep this loop running forever.
            if budget.time_limit * 2 <= timing:
                save_test_results(failures, results_path)
                return
            continue

        for prompt in tqdm(prompts['prompts'], desc='Testing prompts'):
            # Elapsed time at the moment the prompt is handed to the agent.
            timing = time.time() - start_time
            try:
                agent_result = agent_exec(prompt)
                exception_text = agent_result.exception
                if not isinstance(exception_text, str):
                    # Keep the exception type next to its message in the stored record.
                    exception_text = str(type(exception_text)) + str(exception_text)
                failures.append(TestFailureResult(tool=tool_name,
                                                  expected_exception="",
                                                  exception=exception_text,
                                                  prompt=prompt, agent_type=agent_exec.get_name(),
                                                  invocation_params=str(agent_result.tool_args),
                                                  fuzzed_params="",
                                                  successful_trigger=agent_result.is_raised_exception,
                                                  trace=json.dumps(agent_result.trace),
                                                  time=timing,
                                                  agent_tokens=agent_exec.llm_callback.total_tokens,
                                                  prompt_tokens=cb_handler.total_tokens,
                                                  agent_cost=agent_exec.llm_callback.total_cost,
                                                  prompt_cost=cb_handler.total_cost))
            except Exception as e:
                # Anything raised while running the agent or building the record is itself stored
                # as a triggered failure for this prompt.
                print(f"Error {e}")
                failures.append(TestFailureResult(tool=tool_name,
                                                  expected_exception="",
                                                  exception=str(type(e)) + str(e),
                                                  prompt=prompt, agent_type=agent_exec.get_name(),
                                                  fuzzed_params="",
                                                  invocation_params='Invocation was interrupted',
                                                  successful_trigger=True,
                                                  trace=json.dumps({'error': str(e)}),
                                                  time=timing,
                                                  agent_tokens=agent_exec.llm_callback.total_tokens,
                                                  prompt_tokens=cb_handler.total_tokens,
                                                  agent_cost=agent_exec.llm_callback.total_cost,
                                                  prompt_cost=cb_handler.total_cost))

        print("Costs:")
        print(f"Agent cost: {agent_exec.llm_callback.total_cost}")
        print(f"Prompt cost: {cb_handler.total_cost}")
        print(f"Agent tokens: {agent_exec.llm_callback.total_tokens}")
        print(f"Prompt tokens: {cb_handler.total_tokens}")
        print(f"Time: {timing}")

        # Stop once the prompt cost, time and token limits are ALL exceeded, or unconditionally
        # once twice the time limit has passed.
        # NOTE(review): the token check reads `prompt_tokens` while the records above store
        # `total_tokens` - confirm which counter the budget is meant to bound.
        budget_spent = (cb_handler.total_cost > budget.prompt_cost_limit
                        and timing > budget.time_limit
                        and cb_handler.prompt_tokens > budget.prompt_token_limit)
        if budget_spent or budget.time_limit * 2 <= timing:
            save_test_results(failures, results_path)
            return


def correct_test(tools):
    """
    Function which will generate the results for the baseline for correctness.

    The tool named by the ``-t`` command line argument is looked up in ``tools``. The prompt model is
    then asked repeatedly (``prompt_correctness`` template) for a batch of prompts. Every prompt is
    run through the agent, the agent answer is rated by the prompt model (``prompt_oracle``
    template) and one ``TestResult`` is recorded per prompt. When the budget check after a prompt
    succeeds, the results are saved to ``./res_bl_gr_<tool name>_over_budget.json`` and the function
    returns.

    Args:
        tools: The tools which will be evaluated

    Raises:
        ValueError: If none of the tools has the name given on the command line.
    """
    starting_time = time.time()

    tool_to_test = None
    tool_wrapper = None
    for tool in tools:
        candidate = ToolWrapperFactory.create_extractor(tool)
        if candidate.get_tool_name() == cl_args.tool:
            tool_to_test, tool_wrapper = tool, candidate
            break
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`, which would let
    # the run continue without a tool to test.
    if tool_to_test is None:
        raise ValueError(f"Tool {cl_args.tool} not found in the provided tools")

    budget = Budget(time_limit=300, agent_token_limit=25_000, prompt_token_limit=25_000, agent_cost_limit=0.1,
                    prompt_cost_limit=0.15)

    results = []
    model, cb_handler = init_model(cl_args.prompt_model)
    pg = RuntimeFailurePromptGeneration(tool_wrapper, model, cb_handler)
    tool_docs = tool_wrapper.get_tool_docs()

    agent_exec = ReactAgentNew(tool_to_test, cl_args.agent_model)
    # NOTE(review): unlike failure_test there is no hard time cap here, and generation failures are
    # retried without any bound - confirm that this is intended.
    while True:
        try:
            prompts = pg._generate_from_template(prompt_correctness, ['tool_info'], {'tool_info': tool_docs},
                                                 GeneratedPrompt)
        except OutputParserException as e:
            print(f"Cannot parse output: {e}")
            continue
        except Exception as e:
            print(f"Error {e}")
            # Retry with only the first 200 characters of the tool documentation
            # (presumably to cope with documentation that is too long for the model - TODO confirm).
            tool_docs = tool_docs[:200]
            continue

        for prompt in tqdm(prompts['prompts'], desc='Testing prompts'):
            # Elapsed time at the moment the prompt is handed to the agent.
            timing = time.time() - starting_time
            try:
                agent_res = agent_exec(prompt)
                # Let the prompt model rate the agent answer for this prompt.
                resp = pg._generate_from_template(prompt_oracle, ['question', 'answer'],
                                                  {'question': prompt, 'answer': agent_res.agent_response},
                                                  CorrectnessResponse)
                # The usage counters (agent tokens, prompt tokens, agent cost, prompt cost, time)
                # are stored as a comma separated string in `llm_agent_out_reason`.
                results.append(TestResult(prompt=prompt, tool_arguments=tool_args_str(agent_res.tool_args),
                                          tool_output=str(agent_res.tool_output),
                                          agent_output=agent_res.agent_response,
                                          tool_failure=agent_res.is_raised_exception,
                                          unexpected_agent_output=resp['correctness_degree'],
                                          agent_output_not_relevant=False,
                                          llm_agent_out_reason=f"{agent_exec.llm_callback.total_tokens}, {cb_handler.total_tokens}, {agent_exec.llm_callback.total_cost}, {cb_handler.total_cost}, {timing}",
                                          trace=json.dumps(agent_res.trace, default=callback_manager_serializer)))
            except Exception as e:
                # Best effort: a prompt that cannot be evaluated is reported and skipped.
                print(f"Error {e}")

            print("Costs:")
            print(f"Agent cost: {agent_exec.llm_callback.total_cost}")
            print(f"Prompt cost: {cb_handler.total_cost}")
            print(f"Agent tokens: {agent_exec.llm_callback.total_tokens}")
            print(f"Prompt tokens: {cb_handler.total_tokens}")
            print(f"Time: {timing}")

            # Stop once the prompt cost, time and token limits are ALL exceeded.
            if (cb_handler.total_cost > budget.prompt_cost_limit
                    and timing > budget.time_limit
                    and cb_handler.prompt_tokens > budget.prompt_token_limit):
                save_test_results(results, f"./res_bl_gr_{tool_wrapper.get_tool_name()}_over_budget.json")
                return



def tool_args_str(args):
    """
    Render the arguments of a tool invocation as a string.

    Args:
        args: The tool arguments; a string, a mapping of argument names to values, or any other
            object.

    Returns:
        ``args`` itself when it is a string, the string form of the dict without its
        ``run_manager`` entry when it is a dict (including dict subclasses), and ``str(args)``
        otherwise.
    """
    if isinstance(args, str):
        return args
    if isinstance(args, dict):
        # isinstance (not an exact type comparison) so that dict subclasses such as OrderedDict
        # get the `run_manager` entry stripped as well.
        return str({k: v for k, v in args.items() if k != 'run_manager'})
    return str(args)


def main():
    """
    Set up the environment and run the test selected on the command line.

    ``-rt`` selects the runtime failure test, otherwise the correctness test is run; ``-l`` selects
    the langchain tools, otherwise the composio tools are used.
    """
    setup_env_vars()
    run_test = failure_test if cl_args.runtime else correct_test
    load_tools = get_langchain_tools if cl_args.langchain else get_composio_tools
    run_test(load_tools())



# Run the evaluation only when the file is executed as a script.
if __name__ == '__main__':
    main()
