import json

from src.toolfuzz.logging_mixin import LoggingMixin
from src.toolfuzz.runtime.fuzz.taints import TaintDict, TaintStr
from src.toolfuzz.runtime.fuzz.type_generators import StringGenerator, get_generators
from src.toolfuzz.tools.info_extractors.langchain_tool_wrapper import LangchainToolWrapper


class Fuzzer(LoggingMixin):
    """Invoke a tool repeatedly with generated arguments and collect failures.

    Each iteration builds a kwargs dict for the tool, invokes the tool through
    the tool extractor, resets the surrounding context, and inspects the
    output for error markers. Argument sets that lead to an exception (raised
    by the tool or synthesised from an error-looking output) are grouped by
    exception type.
    """

    def __init__(self, tool, context_resetter, max_iterations=100, custom_tool_extractor=None):
        """Set up the fuzzer.

        Args:
            tool: The tool under test. Wrapped in ``LangchainToolWrapper``
                unless ``custom_tool_extractor`` is given.
            context_resetter: Object whose ``reset_context()`` is called after
                every tool invocation.
            max_iterations: Number of invocations performed by ``fuzz``.
            custom_tool_extractor: Optional replacement for the default
                wrapper; it must provide ``get_tool_args()`` and
                ``invoke_tool(kwargs)``.
        """
        super().__init__()
        self.tool_extractor = LangchainToolWrapper(tool) if custom_tool_extractor is None else custom_tool_extractor
        self.max_iterations = max_iterations
        self.tool_args = self.tool_extractor.get_tool_args()
        self.tool = tool
        self.generators = get_generators()
        # Fixed values fed to string arguments on the first iterations
        # (iteration index -> seed) before generated values take over.
        self.str_seed = (None, TaintStr(''), TaintStr('{}'))
        self.context_resetter = context_resetter

    def generate_tool_kwargs(self, iter):
        """Build the keyword arguments for one tool invocation.

        Args:
            iter: Zero-based iteration index. Selects a seed value for string
                arguments while it is below ``len(self.str_seed)`` and drives
                the boolean flag passed to the generators.

        Returns:
            dict mapping argument names to generated values. Arguments without
            a matching generator are omitted when they have a default.

        Raises:
            AssertionError: If an argument has no matching generator and no
                default value.
        """
        arguments = {}
        for arg in self.tool_args:
            # First generator that claims the argument's type wins.
            arg_generator = next((gen for gen in self.generators if gen.can_gen(arg.type)), None)
            if arg_generator is None:
                # Raised explicitly (instead of ``assert``) so the check is not
                # stripped under ``python -O``; the exception type is unchanged.
                if not arg.has_default:
                    raise AssertionError(
                        f'No argument generator found for {arg}. Available generators {self.generators}. Tool {self.tool}'
                    )
                continue

            if iter < len(self.str_seed) and isinstance(arg_generator, StringGenerator):
                arguments[arg.name] = self.str_seed[iter]
            else:
                # NOTE(review): the flag is True on every 10th iteration; an
                # earlier comment said "every 20 iters" - confirm which period
                # is intended. Its meaning is defined by the generator.
                arguments[arg.name] = arg_generator.generate(arg.type, iter % 10 == 0)
        return arguments

    def fuzz(self):
        """Run the fuzzing loop.

        Returns:
            dict mapping exception types to the list of kwargs dicts that
            triggered an exception of that type.
        """
        self._clear_state()
        bad_params = {}
        with CustomJsonParseBlock():
            for iteration in range(self.max_iterations):
                kwargs = self.generate_tool_kwargs(iteration)
                try:
                    self.log_info(f"({iteration}/{self.max_iterations}) Generated args {kwargs}")
                    try:
                        output = self.tool_extractor.invoke_tool(kwargs)
                    finally:
                        # Reset even when the tool raised, so a crashing call
                        # cannot leak state into the next iteration.
                        self.context_resetter.reset_context()
                    self._raise_on_error_output(output)
                except Exception as e:
                    # Deliberately broad: any failure is a fuzzing finding.
                    bad_params.setdefault(type(e), []).append(kwargs)
        return bad_params

    @staticmethod
    def _raise_on_error_output(output):
        """Raise ``RuntimeError`` when a tool output looks like an error report.

        ``isinstance`` is used so that subclasses of ``str``/``dict`` are
        inspected as well (presumably this includes the taint types - confirm
        against their definitions). Outputs of any other type are ignored.
        """
        if isinstance(output, str):
            lowered = output.lower()
            if "error" in lowered or "exception" in lowered:
                raise RuntimeError(f"Tool failed with error: {output}")

        if isinstance(output, dict):
            if 'error' in output and output['error'] is not None:
                raise RuntimeError(f"Tool failed with error: {output['error']}")

            # A missing or non-dict "data" entry is not a failure; indexing it
            # blindly would record a spurious KeyError/TypeError as a finding.
            if "data" not in output:
                return
            data = output["data"]
            if not isinstance(data, dict) or "error" not in data:
                return

            error_msg = data["error"]
            if error_msg is None:
                return
            # Coerce to text so structured (non-string) error payloads are
            # inspected instead of crashing on ``.lower()``.
            lowered_msg = str(error_msg).lower()
            if 'error' in lowered_msg or 'exception' in lowered_msg:
                raise RuntimeError(f"Tool failed with error: {error_msg}")

    @staticmethod
    def _clear_state():
        """Reset the class-level list of accessed keys on ``TaintDict``."""
        TaintDict.acc_keys = []


def load_in_taint_dict(dictionary):
    """Wrap ``dictionary`` in a ``TaintDict`` and return the wrapper.

    The one-argument signature matches what ``json.loads`` expects of an
    ``object_hook`` callable.
    """
    tainted = TaintDict(dictionary)
    return tainted


class CustomJsonParseBlock:
    """Context manager that temporarily patches ``json.loads``.

    While the block is active, the module attribute ``json.loads`` is replaced
    by a wrapper that supplies ``load_in_taint_dict`` as ``object_hook``. The
    function captured at construction time is put back on exit, including
    when the body raises (exceptions are not suppressed).

    The patch replaces the attribute on the ``json`` module itself, so it is
    visible to every caller that looks up ``json.loads`` during the block;
    names bound earlier via ``from json import loads`` are not affected.
    """

    def __init__(self):
        # Captured here (not in __enter__) so __exit__ restores whatever
        # ``json.loads`` was when the instance was created.
        self.original_json_loads = json.loads

    def __enter__(self):
        """Install the wrapper and return this instance."""
        original_loads = self.original_json_loads

        def custom_json_loads(*args, **kwargs):
            # Only inject the taint hook when the caller did not supply one;
            # passing it unconditionally raised "multiple values for keyword
            # argument 'object_hook'" for callers with their own hook.
            if kwargs.get("object_hook") is None:
                kwargs["object_hook"] = load_in_taint_dict
            return original_loads(*args, **kwargs)

        json.loads = custom_json_loads
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original ``json.loads``; never suppresses exceptions."""
        json.loads = self.original_json_loads
